This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 071fb8a429 [RUNTIME] Ensure NDArray.CopyTo(Device) always sync (#16716)
071fb8a429 is described below

commit 071fb8a4290ff1c59f6d99d3ccbe051d5a0a1ff6
Author: Tianqi Chen <[email protected]>
AuthorDate: Thu Mar 14 09:05:57 2024 -0400

    [RUNTIME] Ensure NDArray.CopyTo(Device) always sync (#16716)
    
    This PR ensures that NDArray.CopyTo(Device) always sync.
    Prior to this PR, the behavior is uncertain as the underlying
    DeviceAPI may or may not sync. This PR further clarifies in
    docs about the contract (that low-level device api is always async)
    as well as the sync/async nature of each NDArray API.
---
 include/tvm/runtime/device_api.h |  2 ++
 include/tvm/runtime/ndarray.h    | 12 ++----------
 src/runtime/ndarray.cc           | 11 +++++++++++
 3 files changed, 15 insertions(+), 10 deletions(-)

diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h
index 721990c625..b419212602 100644
--- a/include/tvm/runtime/device_api.h
+++ b/include/tvm/runtime/device_api.h
@@ -147,6 +147,8 @@ class TVM_DLL DeviceAPI {
    * \param from The source array.
    * \param to The target array.
    * \param stream Optional stream object.
+   * \note The copy may happen asynchronously if it involves a GPU context.
+   *       Call StreamSync to ensure the copy completes from host's pov.
    */
   virtual void CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream);
   /*!
diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h
index 8400344bf5..d643355d26 100644
--- a/include/tvm/runtime/ndarray.h
+++ b/include/tvm/runtime/ndarray.h
@@ -112,8 +112,9 @@ class NDArray : public ObjectRef {
    * \param dev The target device.
    * \param mem_scope The memory scope of the target array.
    * \return The array under another device.
+   * \note The copy always triggers a TVMSynchronize.
    */
-  inline NDArray CopyTo(const Device& dev, Optional<String> mem_scope = NullOpt) const;
+  TVM_DLL NDArray CopyTo(const Device& dev, Optional<String> mem_scope = NullOpt) const;
   /*!
    * \brief Load NDArray from stream
    * \param stream The input data stream
@@ -399,15 +400,6 @@ inline void NDArray::CopyTo(const NDArray& other) const {
   CopyFromTo(&(get_mutable()->dl_tensor), &(other.get_mutable()->dl_tensor));
 }
 
-inline NDArray NDArray::CopyTo(const Device& dev, Optional<String> mem_scope) const {
-  ICHECK(data_ != nullptr);
-  const DLTensor* dptr = operator->();
-  NDArray ret =
-      Empty(ShapeTuple(dptr->shape, dptr->shape + dptr->ndim), dptr->dtype, dev, mem_scope);
-  this->CopyTo(ret);
-  return ret;
-}
-
-
 inline int NDArray::use_count() const { return data_.use_count(); }
 
 inline const DLTensor* NDArray::operator->() const { return &(get_mutable()->dl_tensor); }
diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc
index 675ee62a05..6d03e2e01b 100644
--- a/src/runtime/ndarray.cc
+++ b/src/runtime/ndarray.cc
@@ -287,6 +287,17 @@ void NDArray::CopyFromBytes(const void* data, size_t nbytes) {
   ArrayCopyFromBytes(&get_mutable()->dl_tensor, data, nbytes);
 }
 
+NDArray NDArray::CopyTo(const Device& dev, Optional<String> mem_scope) const {
+  ICHECK(data_ != nullptr);
+  const DLTensor* dptr = operator->();
+  NDArray ret =
+      Empty(ShapeTuple(dptr->shape, dptr->shape + dptr->ndim), dptr->dtype, dev, mem_scope);
+  this->CopyTo(ret);
+  Device copy_gpu_dev = dptr->device.device_type != kDLCPU ? dptr->device : dev;
+  DeviceAPI::Get(copy_gpu_dev)->StreamSync(copy_gpu_dev, nullptr);
+  return ret;
+}
+
+
 void NDArray::CopyFromTo(const DLTensor* from, DLTensor* to, TVMStreamHandle stream) {
   size_t from_size = GetDataSize(*from);
   size_t to_size = GetDataSize(*to);

Reply via email to