This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 97ff7cc4f1 [VM][OPENCL] Take advantage of OpenCL host ptr for improved copy (#16929)
97ff7cc4f1 is described below

commit 97ff7cc4f197ef0fa21093448dd3e45e6f1fd2bc
Author: Siva <[email protected]>
AuthorDate: Sat Apr 27 02:07:44 2024 +0530

    [VM][OPENCL] Take advantage of OpenCL host ptr for improved copy (#16929)
    
    We can use the OpenCL mapped host pointer for these copies
    to improve performance.
---
 src/runtime/relax_vm/paged_kv_cache.cc | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/src/runtime/relax_vm/paged_kv_cache.cc 
b/src/runtime/relax_vm/paged_kv_cache.cc
index 64759d465b..efedac235b 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -31,6 +31,9 @@
 #include <vector>
 
 #include "kv_state.h"
+#if defined(OPENCL_ENABLE_HOST_PTR)
+#include "../opencl/opencl_common.h"
+#endif
 
 namespace tvm {
 namespace runtime {
@@ -384,6 +387,22 @@ class PlainPagedKVCacheAuxDataManager : public 
PagedKVCacheAuxDataManager {
       return;
     }
     DLTensor copy_dst = *array.operator->();
+#if defined(OPENCL_ENABLE_HOST_PTR)
+    tvm::runtime::cl::OpenCLWorkspace* workspace = 
tvm::runtime::cl::OpenCLWorkspace::Global();
+    if (workspace->IsOpenCLDevice(copy_dst.device)) {
+      void* nptr = workspace->GetNativePtr(array);
+      uint64_t copy_size;
+      if (shape.defined()) {
+        ICHECK_EQ(shape.value().size(), 1);
+        copy_size = shape.value()->data[0] * sizeof(int32_t);
+      } else {
+        copy_size = 
DeviceAPI::Get(array->device)->GetDataSize(*array.operator->());
+      }
+      memcpy(static_cast<char*>(nptr) + dst_elem_offset * sizeof(int32_t), 
vec_data, copy_size);
+      return;
+    }
+#endif
+
     if (shape.defined()) {
       ICHECK_EQ(shape.value().size(), 1);
       copy_dst.ndim = 1;

Reply via email to