This is an automated email from the ASF dual-hosted git repository.

zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new d062c899fb GH-36103: [C++] Initial device sync API (#37040)
d062c899fb is described below

commit d062c899fbdaa7ce8b67ebe657c4bb76bf918d94
Author: Matt Topol <[email protected]>
AuthorDate: Tue Aug 22 15:08:37 2023 -0400

    GH-36103: [C++] Initial device sync API (#37040)
    
    
    
    ### Rationale for this change
    Building on the `ArrowDeviceArray` we need to expand the abstractions for 
handling events and stream synchronization for devices.
    
    ### What changes are included in this PR?
    Initial Abstract implementations for the new DeviceSync API and a CPU 
implementation. This will be followed up by a CUDA implementation in a 
subsequent PR.
    
    ### Are these changes tested?
    Yes, tests are added for Import/Export DeviceArrays using the DeviceSync 
handling.
    
    * Closes: #36103
    
    Lead-authored-by: Matt Topol <[email protected]>
    Co-authored-by: Benjamin Kietzman <[email protected]>
    Co-authored-by: Antoine Pitrou <[email protected]>
    Signed-off-by: Matt Topol <[email protected]>
---
 cpp/src/arrow/buffer.h            |   2 +
 cpp/src/arrow/c/bridge.cc         |  44 +++---
 cpp/src/arrow/c/bridge.h          |  29 ++--
 cpp/src/arrow/c/bridge_test.cc    | 278 +++++++++++++++++++++++++++++++++-----
 cpp/src/arrow/device.cc           |   9 ++
 cpp/src/arrow/device.h            |  66 +++++++++
 cpp/src/arrow/gpu/cuda_context.cc |   7 +
 cpp/src/arrow/gpu/cuda_context.h  |   3 +
 8 files changed, 365 insertions(+), 73 deletions(-)

diff --git a/cpp/src/arrow/buffer.h b/cpp/src/arrow/buffer.h
index 08a3bd749e..7cc2d2c9cc 100644
--- a/cpp/src/arrow/buffer.h
+++ b/cpp/src/arrow/buffer.h
@@ -346,6 +346,8 @@ class ARROW_EXPORT Buffer {
   static Result<std::shared_ptr<Buffer>> ViewOrCopy(
       std::shared_ptr<Buffer> source, const std::shared_ptr<MemoryManager>& 
to);
 
+  virtual std::shared_ptr<Device::SyncEvent> device_sync_event() { return 
NULLPTR; }
+
  protected:
   bool is_mutable_;
   bool is_cpu_;
diff --git a/cpp/src/arrow/c/bridge.cc b/cpp/src/arrow/c/bridge.cc
index 13355dd6d0..b967af28e4 100644
--- a/cpp/src/arrow/c/bridge.cc
+++ b/cpp/src/arrow/c/bridge.cc
@@ -521,8 +521,7 @@ struct ExportedArrayPrivateData : 
PoolAllocationMixin<ExportedArrayPrivateData>
   SmallVector<struct ArrowArray*, 4> child_pointers_;
 
   std::shared_ptr<ArrayData> data_;
-
-  RawSyncEvent sync_event_;
+  std::shared_ptr<Device::SyncEvent> sync_;
 
   ExportedArrayPrivateData() = default;
   ARROW_DEFAULT_MOVE_AND_ASSIGN(ExportedArrayPrivateData);
@@ -547,10 +546,6 @@ void ReleaseExportedArray(struct ArrowArray* array) {
   }
   DCHECK_NE(array->private_data, nullptr);
   auto* pdata = 
reinterpret_cast<ExportedArrayPrivateData*>(array->private_data);
-  if (pdata->sync_event_.sync_event != nullptr &&
-      pdata->sync_event_.release_func != nullptr) {
-    pdata->sync_event_.release_func(pdata->sync_event_.sync_event);
-  }
   delete pdata;
 
   ArrowArrayMarkReleased(array);
@@ -591,7 +586,7 @@ struct ArrayExporter {
     // Store owning pointer to ArrayData
     export_.data_ = data;
 
-    export_.sync_event_ = RawSyncEvent();
+    export_.sync_ = nullptr;
     return Status::OK();
   }
 
@@ -714,12 +709,9 @@ Result<std::pair<std::optional<DeviceAllocationType>, 
int64_t>> ValidateDeviceIn
   return std::make_pair(device_type, device_id);
 }
 
-Status ExportDeviceArray(const Array& array, RawSyncEvent sync_event,
+Status ExportDeviceArray(const Array& array, 
std::shared_ptr<Device::SyncEvent> sync,
                          struct ArrowDeviceArray* out, struct ArrowSchema* 
out_schema) {
-  if (sync_event.sync_event != nullptr && sync_event.release_func) {
-    return Status::Invalid(
-        "Must provide a release event function if providing a non-null event");
-  }
+  void* sync_event = sync ? sync->get_raw() : nullptr;
 
   SchemaExportGuard guard(out_schema);
   if (out_schema != nullptr) {
@@ -739,19 +731,20 @@ Status ExportDeviceArray(const Array& array, RawSyncEvent 
sync_event,
   exporter.Finish(&out->array);
 
   auto* pdata = 
reinterpret_cast<ExportedArrayPrivateData*>(out->array.private_data);
-  pdata->sync_event_ = sync_event;
-  out->sync_event = sync_event.sync_event;
+  pdata->sync_ = std::move(sync);
+  out->sync_event = sync_event;
 
   guard.Detach();
   return Status::OK();
 }
 
-Status ExportDeviceRecordBatch(const RecordBatch& batch, RawSyncEvent 
sync_event,
+Status ExportDeviceRecordBatch(const RecordBatch& batch,
+                               std::shared_ptr<Device::SyncEvent> sync,
                                struct ArrowDeviceArray* out,
                                struct ArrowSchema* out_schema) {
-  if (sync_event.sync_event != nullptr && sync_event.release_func == nullptr) {
-    return Status::Invalid(
-        "Must provide a release event function if providing a non-null event");
+  void* sync_event{nullptr};
+  if (sync) {
+    sync_event = sync->get_raw();
   }
 
   // XXX perhaps bypass ToStructArray for speed?
@@ -776,8 +769,8 @@ Status ExportDeviceRecordBatch(const RecordBatch& batch, 
RawSyncEvent sync_event
   exporter.Finish(&out->array);
 
   auto* pdata = 
reinterpret_cast<ExportedArrayPrivateData*>(out->array.private_data);
-  pdata->sync_event_ = sync_event;
-  out->sync_event = sync_event.sync_event;
+  pdata->sync_ = std::move(sync);
+  out->sync_event = sync_event;
 
   guard.Detach();
   return Status::OK();
@@ -1362,7 +1355,7 @@ namespace {
 // The ArrowArray is released on destruction.
 struct ImportedArrayData {
   struct ArrowArray array_;
-  void* sync_event_;
+  std::shared_ptr<Device::SyncEvent> device_sync_;
 
   ImportedArrayData() {
     ArrowArrayMarkReleased(&array_);  // Initially released
@@ -1395,6 +1388,10 @@ class ImportedBuffer : public Buffer {
 
   ~ImportedBuffer() override {}
 
+  std::shared_ptr<Device::SyncEvent> device_sync_event() override {
+    return import_->device_sync_;
+  }
+
  protected:
   std::shared_ptr<ImportedArrayData> import_;
 };
@@ -1409,7 +1406,10 @@ struct ArrayImporter {
     ARROW_ASSIGN_OR_RAISE(memory_mgr_, mapper(src->device_type, 
src->device_id));
     device_type_ = static_cast<DeviceAllocationType>(src->device_type);
     RETURN_NOT_OK(Import(&src->array));
-    import_->sync_event_ = src->sync_event;
+    if (src->sync_event != nullptr) {
+      ARROW_ASSIGN_OR_RAISE(import_->device_sync_, 
memory_mgr_->WrapDeviceSyncEvent(
+                                                       src->sync_event, 
[](void*) {}));
+    }
     // reset internal state before next import
     memory_mgr_.reset();
     device_type_ = DeviceAllocationType::kCPU;
diff --git a/cpp/src/arrow/c/bridge.h b/cpp/src/arrow/c/bridge.h
index 92707a5972..45583109a7 100644
--- a/cpp/src/arrow/c/bridge.h
+++ b/cpp/src/arrow/c/bridge.h
@@ -22,6 +22,7 @@
 #include <string>
 
 #include "arrow/c/abi.h"
+#include "arrow/device.h"
 #include "arrow/result.h"
 #include "arrow/status.h"
 #include "arrow/type_fwd.h"
@@ -172,17 +173,6 @@ Result<std::shared_ptr<RecordBatch>> 
ImportRecordBatch(struct ArrowArray* array,
 ///
 /// @{
 
-/// \brief EXPERIMENTAL: Type for freeing a sync event
-///
-/// If synchronization is necessary for accessing the data on a device,
-/// a pointer to an event needs to be passed when exporting the device
-/// array. It's the responsibility of the release function for the array
-/// to release the event. Both can be null if no sync'ing is necessary.
-struct RawSyncEvent {
-  void* sync_event = NULL;
-  std::function<void(void*)> release_func;
-};
-
 /// \brief EXPERIMENTAL: Export C++ Array as an ArrowDeviceArray.
 ///
 /// The resulting ArrowDeviceArray struct keeps the array data and buffers 
alive
@@ -190,15 +180,15 @@ struct RawSyncEvent {
 /// the provided array MUST have the same device_type, otherwise an error
 /// will be returned.
 ///
-/// If a non-null sync_event is provided, then the sync_release func must also 
be
-/// non-null. If the sync_event is null, then the sync_release parameter is 
not called.
+/// If sync is non-null, get_raw will be called on it in order to
+/// potentially provide an event for consumers to synchronize on.
 ///
 /// \param[in] array Array object to export
-/// \param[in] sync_event A struct containing what is needed for syncing if 
necessary
+/// \param[in] sync shared_ptr to object derived from Device::SyncEvent or null
 /// \param[out] out C struct to export the array to
 /// \param[out] out_schema optional C struct to export the array type to
 ARROW_EXPORT
-Status ExportDeviceArray(const Array& array, RawSyncEvent sync_event,
+Status ExportDeviceArray(const Array& array, 
std::shared_ptr<Device::SyncEvent> sync,
                          struct ArrowDeviceArray* out,
                          struct ArrowSchema* out_schema = NULLPTR);
 
@@ -212,15 +202,16 @@ Status ExportDeviceArray(const Array& array, RawSyncEvent 
sync_event,
 /// otherwise an error will be returned. If columns are on different devices,
 /// they should be exported using different ArrowDeviceArray instances.
 ///
-/// If a non-null sync_event is provided, then the sync_release func must also 
be
-/// non-null. If the sync_event is null, then the sync_release parameter is 
ignored.
+/// If sync is non-null, get_raw will be called on it in order to
+/// potentially provide an event for consumers to synchronize on.
 ///
 /// \param[in] batch Record batch to export
-/// \param[in] sync_event A struct containing what is needed for syncing if 
necessary
+/// \param[in] sync shared_ptr to object derived from Device::SyncEvent or null
 /// \param[out] out C struct where to export the record batch
 /// \param[out] out_schema optional C struct where to export the record batch 
schema
 ARROW_EXPORT
-Status ExportDeviceRecordBatch(const RecordBatch& batch, RawSyncEvent 
sync_event,
+Status ExportDeviceRecordBatch(const RecordBatch& batch,
+                               std::shared_ptr<Device::SyncEvent> sync,
                                struct ArrowDeviceArray* out,
                                struct ArrowSchema* out_schema = NULLPTR);
 
diff --git a/cpp/src/arrow/c/bridge_test.cc b/cpp/src/arrow/c/bridge_test.cc
index 5c7de8e4a0..9727403163 100644
--- a/cpp/src/arrow/c/bridge_test.cc
+++ b/cpp/src/arrow/c/bridge_test.cc
@@ -1135,12 +1135,49 @@ TEST_F(TestArrayExport, ExportRecordBatch) {
 
 static const char kMyDeviceTypeName[] = "arrowtest::MyDevice";
 static const ArrowDeviceType kMyDeviceType = ARROW_DEVICE_EXT_DEV;
+static const void* kMyEventPtr = 
reinterpret_cast<void*>(uintptr_t(0xBAADF00D));
 
 class MyBuffer final : public MutableBuffer {
  public:
   using MutableBuffer::MutableBuffer;
 
   ~MyBuffer() { default_memory_pool()->Free(const_cast<uint8_t*>(data_), 
size_); }
+
+  std::shared_ptr<Device::SyncEvent> device_sync_event() override { return 
device_sync_; }
+
+ protected:
+  std::shared_ptr<Device::SyncEvent> device_sync_;
+};
+
+class MyDevice : public Device {
+ public:
+  explicit MyDevice(int64_t value) : Device(true), value_(value) {}
+  const char* type_name() const override { return kMyDeviceTypeName; }
+  std::string ToString() const override { return kMyDeviceTypeName; }
+  bool Equals(const Device& other) const override {
+    if (other.type_name() != kMyDeviceTypeName || other.device_type() != 
device_type()) {
+      return false;
+    }
+    return checked_cast<const MyDevice&>(other).value_ == value_;
+  }
+  DeviceAllocationType device_type() const override {
+    return static_cast<DeviceAllocationType>(kMyDeviceType);
+  }
+  int64_t device_id() const override { return value_; }
+  std::shared_ptr<MemoryManager> default_memory_manager() override;
+
+  class MySyncEvent final : public Device::SyncEvent {
+   public:
+    explicit MySyncEvent(void* sync_event, release_fn_t release_sync_event)
+        : Device::SyncEvent(sync_event, release_sync_event) {}
+
+    virtual ~MySyncEvent() = default;
+    Status Wait() override { return Status::OK(); }
+    Status Record(const Device::Stream&) override { return Status::OK(); }
+  };
+
+ protected:
+  int64_t value_;
 };
 
 class MyMemoryManager : public CPUMemoryManager {
@@ -1154,6 +1191,16 @@ class MyMemoryManager : public CPUMemoryManager {
     return std::make_unique<MyBuffer>(data, size, shared_from_this());
   }
 
+  Result<std::shared_ptr<Device::SyncEvent>> MakeDeviceSyncEvent() override {
+    return 
std::make_shared<MyDevice::MySyncEvent>(const_cast<void*>(kMyEventPtr),
+                                                   [](void*) {});
+  }
+
+  Result<std::shared_ptr<Device::SyncEvent>> WrapDeviceSyncEvent(
+      void* sync_event, Device::SyncEvent::release_fn_t release_sync_event) 
override {
+    return std::make_shared<MyDevice::MySyncEvent>(sync_event, 
release_sync_event);
+  }
+
  protected:
   Result<std::shared_ptr<Buffer>> CopyBufferFrom(
       const std::shared_ptr<Buffer>& buf,
@@ -1174,28 +1221,9 @@ class MyMemoryManager : public CPUMemoryManager {
   }
 };
 
-class MyDevice : public Device {
- public:
-  explicit MyDevice(int value) : Device(true), value_(value) {}
-  const char* type_name() const override { return kMyDeviceTypeName; }
-  std::string ToString() const override { return kMyDeviceTypeName; }
-  bool Equals(const Device& other) const override {
-    if (other.type_name() != kMyDeviceTypeName || other.device_type() != 
device_type()) {
-      return false;
-    }
-    return checked_cast<const MyDevice&>(other).value_ == value_;
-  }
-  DeviceAllocationType device_type() const override {
-    return static_cast<DeviceAllocationType>(kMyDeviceType);
-  }
-  int64_t device_id() const override { return value_; }
-  std::shared_ptr<MemoryManager> default_memory_manager() override {
-    return std::make_shared<MyMemoryManager>(shared_from_this());
-  }
-
- protected:
-  int value_;
-};
+std::shared_ptr<MemoryManager> MyDevice::default_memory_manager() {
+  return std::make_shared<MyMemoryManager>(shared_from_this());
+}
 
 class TestDeviceArrayExport : public ::testing::Test {
  public:
@@ -1251,7 +1279,8 @@ class TestDeviceArrayExport : public ::testing::Test {
                        ", array data = ", arr->ToString());
     const ArrayData& data = *arr->data();  // non-owning reference
     struct ArrowDeviceArray c_export;
-    ASSERT_OK(ExportDeviceArray(*arr, {nullptr, nullptr}, &c_export));
+    std::shared_ptr<Device::SyncEvent> sync{nullptr};
+    ASSERT_OK(ExportDeviceArray(*arr, sync, &c_export));
 
     ArrayExportGuard guard(&c_export.array);
     auto new_bytes = pool_->bytes_allocated();
@@ -1455,7 +1484,8 @@ TEST_F(TestDeviceArrayExport, ExportArrayAndType) {
   ArrayExportGuard array_guard(&c_array.array);
 
   auto array = ToDevice(mm, *ArrayFromJSON(int8(), "[1, 2, 
3]")->data()).ValueOrDie();
-  ASSERT_OK(ExportDeviceArray(*array, {nullptr, nullptr}, &c_array, 
&c_schema));
+  auto sync = mm->MakeDeviceSyncEvent().ValueOrDie();
+  ASSERT_OK(ExportDeviceArray(*array, sync, &c_array, &c_schema));
   const ArrayData& data = *array->data();
   array.reset();
   ASSERT_FALSE(ArrowSchemaIsReleased(&c_schema));
@@ -1463,7 +1493,7 @@ TEST_F(TestDeviceArrayExport, ExportArrayAndType) {
   ASSERT_EQ(c_schema.format, std::string("c"));
   ASSERT_EQ(c_schema.n_children, 0);
   ArrayExportChecker checker{};
-  checker(&c_array, data, kMyDeviceType, 1, nullptr);
+  checker(&c_array, data, kMyDeviceType, 1, kMyEventPtr);
 }
 
 TEST_F(TestDeviceArrayExport, ExportRecordBatch) {
@@ -1481,25 +1511,25 @@ TEST_F(TestDeviceArrayExport, ExportRecordBatch) {
                   .ValueOrDie();
 
   auto batch_factory = [&]() { return RecordBatch::Make(schema, 3, {arr0, 
arr1}); };
-
+  auto sync = mm->MakeDeviceSyncEvent().ValueOrDie();
   {
     auto batch = batch_factory();
 
-    ASSERT_OK(ExportDeviceRecordBatch(*batch, {nullptr, nullptr}, &c_array, 
&c_schema));
+    ASSERT_OK(ExportDeviceRecordBatch(*batch, sync, &c_array, &c_schema));
     SchemaExportGuard schema_guard(&c_schema);
     ArrayExportGuard array_guard(&c_array.array);
     RecordBatchExportChecker checker{};
-    checker(&c_array, *batch, kMyDeviceType, 1, nullptr);
+    checker(&c_array, *batch, kMyDeviceType, 1, kMyEventPtr);
 
     // create batch anew, with the same buffer pointers
     batch = batch_factory();
-    checker(&c_array, *batch, kMyDeviceType, 1, nullptr);
+    checker(&c_array, *batch, kMyDeviceType, 1, kMyEventPtr);
   }
   {
     // Check one can export both schema and record batch at once
     auto batch = batch_factory();
 
-    ASSERT_OK(ExportDeviceRecordBatch(*batch, {nullptr, nullptr}, &c_array, 
&c_schema));
+    ASSERT_OK(ExportDeviceRecordBatch(*batch, sync, &c_array, &c_schema));
     SchemaExportGuard schema_guard(&c_schema);
     ArrayExportGuard array_guard(&c_array.array);
     ASSERT_EQ(c_schema.format, std::string("+s"));
@@ -1508,11 +1538,11 @@ TEST_F(TestDeviceArrayExport, ExportRecordBatch) {
     ASSERT_EQ(kEncodedMetadata2,
               std::string(c_schema.metadata, kEncodedMetadata2.size()));
     RecordBatchExportChecker checker{};
-    checker(&c_array, *batch, kMyDeviceType, 1, nullptr);
+    checker(&c_array, *batch, kMyDeviceType, 1, kMyEventPtr);
 
     // Create batch anew, with the same buffer pointers
     batch = batch_factory();
-    checker(&c_array, *batch, kMyDeviceType, 1, nullptr);
+    checker(&c_array, *batch, kMyDeviceType, 1, kMyEventPtr);
   }
 }
 
@@ -3552,6 +3582,190 @@ TEST_F(TestArrayRoundtrip, RecordBatch) {
   }
 }
 
+class TestDeviceArrayRoundtrip : public ::testing::Test {
+ public:
+  using ArrayFactory = std::function<Result<std::shared_ptr<Array>>()>;
+
+  void SetUp() override { pool_ = default_memory_pool(); }
+
+  static Result<std::shared_ptr<MemoryManager>> DeviceMapper(ArrowDeviceType 
type,
+                                                             int64_t id) {
+    if (type != kMyDeviceType) {
+      return Status::NotImplemented("should only be MyDevice");
+    }
+
+    std::shared_ptr<Device> device = std::make_shared<MyDevice>(id);
+    return device->default_memory_manager();
+  }
+
+  static Result<std::shared_ptr<ArrayData>> ToDeviceData(
+      const std::shared_ptr<MemoryManager>& mm, const ArrayData& data) {
+    arrow::BufferVector buffers;
+    for (const auto& buf : data.buffers) {
+      if (buf) {
+        ARROW_ASSIGN_OR_RAISE(auto dest, mm->CopyBuffer(buf, mm));
+        buffers.push_back(dest);
+      } else {
+        buffers.push_back(nullptr);
+      }
+    }
+
+    arrow::ArrayDataVector children;
+    for (const auto& child : data.child_data) {
+      ARROW_ASSIGN_OR_RAISE(auto dest, ToDeviceData(mm, *child));
+      children.push_back(dest);
+    }
+
+    return ArrayData::Make(data.type, data.length, buffers, children, 
data.null_count,
+                           data.offset);
+  }
+
+  static Result<std::shared_ptr<Array>> ToDevice(const 
std::shared_ptr<MemoryManager>& mm,
+                                                 const ArrayData& data) {
+    ARROW_ASSIGN_OR_RAISE(auto result, ToDeviceData(mm, data));
+    return MakeArray(result);
+  }
+
+  static ArrayFactory ToDeviceFactory(const std::shared_ptr<MemoryManager>& mm,
+                                      ArrayFactory&& factory) {
+    return [&]() -> Result<std::shared_ptr<Array>> {
+      ARROW_ASSIGN_OR_RAISE(auto arr, factory());
+      return ToDevice(mm, *arr->data());
+    };
+  }
+
+  static ArrayFactory JSONArrayFactory(const std::shared_ptr<MemoryManager>& 
mm,
+                                       std::shared_ptr<DataType> type, const 
char* json) {
+    return [=]() { return ToDevice(mm, *ArrayFromJSON(type, json)->data()); };
+  }
+
+  static ArrayFactory SlicedArrayFactory(ArrayFactory factory) {
+    return [=]() -> Result<std::shared_ptr<Array>> {
+      ARROW_ASSIGN_OR_RAISE(auto arr, factory());
+      DCHECK_GE(arr->length(), 2);
+      return arr->Slice(1, arr->length() - 2);
+    };
+  }
+
+  template <typename ArrayFactory>
+  void TestWithArrayFactory(ArrayFactory&& factory) {
+    TestWithArrayFactory(factory, factory);
+  }
+
+  template <typename ArrayFactory, typename ExpectedArrayFactory>
+  void TestWithArrayFactory(ArrayFactory&& factory,
+                            ExpectedArrayFactory&& factory_expected) {
+    std::shared_ptr<Array> array;
+    struct ArrowDeviceArray c_array {};
+    struct ArrowSchema c_schema {};
+    ArrayExportGuard array_guard(&c_array.array);
+    SchemaExportGuard schema_guard(&c_schema);
+
+    auto orig_bytes = pool_->bytes_allocated();
+
+    ASSERT_OK_AND_ASSIGN(array, ToResult(factory()));
+    ASSERT_OK(ExportType(*array->type(), &c_schema));
+    std::shared_ptr<Device::SyncEvent> sync{nullptr};
+    ASSERT_OK(ExportDeviceArray(*array, sync, &c_array));
+
+    auto new_bytes = pool_->bytes_allocated();
+    if (array->type_id() != Type::NA) {
+      ASSERT_GT(new_bytes, orig_bytes);
+    }
+
+    array.reset();
+    ASSERT_EQ(pool_->bytes_allocated(), new_bytes);
+    ASSERT_OK_AND_ASSIGN(array, ImportDeviceArray(&c_array, &c_schema, 
DeviceMapper));
+    ASSERT_OK(array->ValidateFull());
+    ASSERT_TRUE(ArrowSchemaIsReleased(&c_schema));
+    ASSERT_TRUE(ArrowArrayIsReleased(&c_array.array));
+
+    // Re-export and re-import, now both at once
+    ASSERT_OK(ExportDeviceArray(*array, sync, &c_array, &c_schema));
+    array.reset();
+    ASSERT_OK_AND_ASSIGN(array, ImportDeviceArray(&c_array, &c_schema, 
DeviceMapper));
+    ASSERT_OK(array->ValidateFull());
+    ASSERT_TRUE(ArrowSchemaIsReleased(&c_schema));
+    ASSERT_TRUE(ArrowArrayIsReleased(&c_array.array));
+
+    // Check value of imported array
+    {
+      std::shared_ptr<Array> expected;
+      ASSERT_OK_AND_ASSIGN(expected, ToResult(factory_expected()));
+      AssertTypeEqual(*expected->type(), *array->type());
+      AssertArraysEqual(*expected, *array, true);
+    }
+    array.reset();
+    ASSERT_EQ(pool_->bytes_allocated(), orig_bytes);
+  }
+
+  template <typename BatchFactory>
+  void TestWithBatchFactory(BatchFactory&& factory) {
+    std::shared_ptr<Device> device = std::make_shared<MyDevice>(1);
+    auto mm = device->default_memory_manager();
+
+    std::shared_ptr<RecordBatch> batch;
+    struct ArrowDeviceArray c_array {};
+    struct ArrowSchema c_schema {};
+    ArrayExportGuard array_guard(&c_array.array);
+    SchemaExportGuard schema_guard(&c_schema);
+
+    auto orig_bytes = pool_->bytes_allocated();
+    ASSERT_OK_AND_ASSIGN(batch, ToResult(factory()));
+    ASSERT_OK(ExportSchema(*batch->schema(), &c_schema));
+    ASSERT_OK_AND_ASSIGN(auto sync, mm->MakeDeviceSyncEvent());
+    ASSERT_OK(ExportDeviceRecordBatch(*batch, sync, &c_array));
+
+    auto new_bytes = pool_->bytes_allocated();
+    batch.reset();
+    ASSERT_EQ(pool_->bytes_allocated(), new_bytes);
+    ASSERT_OK_AND_ASSIGN(batch,
+                         ImportDeviceRecordBatch(&c_array, &c_schema, 
DeviceMapper));
+    ASSERT_OK(batch->ValidateFull());
+    ASSERT_TRUE(ArrowSchemaIsReleased(&c_schema));
+    ASSERT_TRUE(ArrowArrayIsReleased(&c_array.array));
+
+    // Re-export and re-import, now both at once
+    ASSERT_OK(ExportDeviceRecordBatch(*batch, sync, &c_array, &c_schema));
+    batch.reset();
+    ASSERT_OK_AND_ASSIGN(batch,
+                         ImportDeviceRecordBatch(&c_array, &c_schema, 
DeviceMapper));
+    ASSERT_OK(batch->ValidateFull());
+    ASSERT_TRUE(ArrowSchemaIsReleased(&c_schema));
+    ASSERT_TRUE(ArrowArrayIsReleased(&c_array.array));
+
+    // Check value of imported record batch
+    {
+      std::shared_ptr<RecordBatch> expected;
+      ASSERT_OK_AND_ASSIGN(expected, ToResult(factory()));
+      AssertSchemaEqual(*expected->schema(), *batch->schema());
+      AssertBatchesEqual(*expected, *batch);
+    }
+    batch.reset();
+    ASSERT_EQ(pool_->bytes_allocated(), orig_bytes);
+  }
+
+  void TestWithJSON(const std::shared_ptr<MemoryManager>& mm,
+                    std::shared_ptr<DataType> type, const char* json) {
+    TestWithArrayFactory(JSONArrayFactory(mm, type, json));
+  }
+
+  void TestWithJSONSliced(const std::shared_ptr<MemoryManager>& mm,
+                          std::shared_ptr<DataType> type, const char* json) {
+    TestWithArrayFactory(SlicedArrayFactory(JSONArrayFactory(mm, type, json)));
+  }
+
+ protected:
+  MemoryPool* pool_;
+};
+
+TEST_F(TestDeviceArrayRoundtrip, Primitive) {
+  std::shared_ptr<Device> device = std::make_shared<MyDevice>(1);
+  auto mm = device->default_memory_manager();
+
+  TestWithJSON(mm, int32(), "[4, 5, null]");
+}
+
 // TODO C -> C++ -> C roundtripping tests?
 
 ////////////////////////////////////////////////////////////////////////////
diff --git a/cpp/src/arrow/device.cc b/cpp/src/arrow/device.cc
index fbb0c3e1a4..14d3bac0af 100644
--- a/cpp/src/arrow/device.cc
+++ b/cpp/src/arrow/device.cc
@@ -29,6 +29,15 @@ namespace arrow {
 
 MemoryManager::~MemoryManager() {}
 
+Result<std::shared_ptr<Device::SyncEvent>> 
MemoryManager::MakeDeviceSyncEvent() {
+  return nullptr;
+}
+
+Result<std::shared_ptr<Device::SyncEvent>> MemoryManager::WrapDeviceSyncEvent(
+    void* sync_event, Device::SyncEvent::release_fn_t release_sync_event) {
+  return nullptr;
+}
+
 Device::~Device() {}
 
 #define COPY_BUFFER_SUCCESS(maybe_buffer) \
diff --git a/cpp/src/arrow/device.h b/cpp/src/arrow/device.h
index 9cc68fe8c8..55037ac418 100644
--- a/cpp/src/arrow/device.h
+++ b/cpp/src/arrow/device.h
@@ -22,6 +22,8 @@
 #include <string>
 
 #include "arrow/io/type_fwd.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
 #include "arrow/type_fwd.h"
 #include "arrow/util/compare.h"
 #include "arrow/util/macros.h"
@@ -98,6 +100,54 @@ class ARROW_EXPORT Device : public 
std::enable_shared_from_this<Device>,
   /// \brief Return the DeviceAllocationType of this device
   virtual DeviceAllocationType device_type() const = 0;
 
+  class SyncEvent;
+
+  /// \brief EXPERIMENTAL: An opaque wrapper for Device-specific streams
+  ///
+  /// In essence this is just a wrapper around a void* to represent the
+  /// standard concept of a stream/queue on a device. Derived classes
+  /// should be trivially constructible from its device-specific counterparts.
+  class ARROW_EXPORT Stream {
+   public:
+    virtual const void* get_raw() const { return NULLPTR; }
+
+    /// \brief Make the stream wait on the provided event.
+    ///
+    /// Tells the stream that it should wait until the synchronization
+    /// event is completed without blocking the CPU.
+    virtual Status WaitEvent(const SyncEvent&) = 0;
+
+   protected:
+    Stream() = default;
+    virtual ~Stream() = default;
+  };
+
+  /// \brief EXPERIMENTAL: An object that provides event/stream sync primitives
+  class ARROW_EXPORT SyncEvent {
+   public:
+    using release_fn_t = void (*)(void*);
+
+    virtual ~SyncEvent() = default;
+
+    void* get_raw() { return sync_event_.get(); }
+
+    /// @brief Block until sync event is completed.
+    virtual Status Wait() = 0;
+
+    /// @brief Record the wrapped event on the stream so it triggers
+    /// the event when the stream gets to that point in its queue.
+    virtual Status Record(const Stream&) = 0;
+
+   protected:
+    /// Wrap a raw device event together with the function that releases it.
+    /// The pair is held in a std::unique_ptr, so release_sync_event runs on
+    /// a non-null sync_event when this SyncEvent is destroyed.
+    explicit SyncEvent(void* sync_event, release_fn_t release_sync_event)
+        : sync_event_{sync_event, release_sync_event} {}
+
+    std::unique_ptr<void, release_fn_t> sync_event_;
+  };
+
  protected:
   ARROW_DISALLOW_COPY_AND_ASSIGN(Device);
   explicit Device(bool is_cpu = false) : is_cpu_(is_cpu) {}
@@ -165,6 +215,22 @@ class ARROW_EXPORT MemoryManager : public 
std::enable_shared_from_this<MemoryMan
   static Result<std::shared_ptr<Buffer>> ViewBuffer(
       const std::shared_ptr<Buffer>& source, const 
std::shared_ptr<MemoryManager>& to);
 
+  /// \brief Create a new SyncEvent.
+  ///
+  /// This version should construct the appropriate event for the device and
+  /// provide the unique_ptr with the correct deleter for the event type.
+  /// If the device does not require or work with any synchronization, it is
+  /// allowed for it to return a nullptr.
+  virtual Result<std::shared_ptr<Device::SyncEvent>> MakeDeviceSyncEvent();
+
+  /// \brief Wrap an event into a SyncEvent.
+  ///
+  /// @param sync_event passed in sync_event from the imported device array.
+  /// @param release_sync_event destructor to free sync_event. Pass a no-op
+  ///        function (not `nullptr`) when no destruction/freeing is necessary
+  virtual Result<std::shared_ptr<Device::SyncEvent>> WrapDeviceSyncEvent(
+      void* sync_event, Device::SyncEvent::release_fn_t release_sync_event);
+
  protected:
   ARROW_DISALLOW_COPY_AND_ASSIGN(MemoryManager);
 
diff --git a/cpp/src/arrow/gpu/cuda_context.cc 
b/cpp/src/arrow/gpu/cuda_context.cc
index 869ea6453c..3e1af26cac 100644
--- a/cpp/src/arrow/gpu/cuda_context.cc
+++ b/cpp/src/arrow/gpu/cuda_context.cc
@@ -293,6 +293,13 @@ std::shared_ptr<CudaDevice> 
CudaMemoryManager::cuda_device() const {
   return checked_pointer_cast<CudaDevice>(device_);
 }
 
+Result<std::shared_ptr<Device::SyncEvent>> 
CudaMemoryManager::WrapDeviceSyncEvent(
+    void* sync_event, Device::SyncEvent::release_fn_t release_sync_event) {
+  return nullptr;
+  // auto ev = reinterpret_cast<CUstream*>(sync_event);
+  // return std::make_shared<CudaDeviceSync>(ev);
+}
+
 Result<std::shared_ptr<io::RandomAccessFile>> 
CudaMemoryManager::GetBufferReader(
     std::shared_ptr<Buffer> buf) {
   if (*buf->device() != *device_) {
diff --git a/cpp/src/arrow/gpu/cuda_context.h b/cpp/src/arrow/gpu/cuda_context.h
index a1b95c7b41..79a2ec9f97 100644
--- a/cpp/src/arrow/gpu/cuda_context.h
+++ b/cpp/src/arrow/gpu/cuda_context.h
@@ -179,6 +179,9 @@ class ARROW_EXPORT CudaMemoryManager : public MemoryManager 
{
   /// having to cast the `device()` result.
   std::shared_ptr<CudaDevice> cuda_device() const;
 
+  Result<std::shared_ptr<Device::SyncEvent>> WrapDeviceSyncEvent(
+      void* sync_event, Device::SyncEvent::release_fn_t release_sync_event) 
override;
+
  protected:
   using MemoryManager::MemoryManager;
   static std::shared_ptr<CudaMemoryManager> Make(const 
std::shared_ptr<Device>& device);

Reply via email to