This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new fc78b22fbc [Relax][VM] Refactor CUDA graph builtins as VM extension (#16823)
fc78b22fbc is described below

commit fc78b22fbc469153f4d50de10891374e2c47f8bc
Author: Wuwei Lin <[email protected]>
AuthorDate: Mon Apr 1 15:23:54 2024 -0700

    [Relax][VM] Refactor CUDA graph builtins as VM extension (#16823)
    
    * [Relax][VM] Refactor CUDA graph builtins as VM extension
    
    * skip test
---
 include/tvm/runtime/relax_vm/vm.h                  | 44 ++++++++++++++++
 src/runtime/relax_vm/cuda/cuda_graph_builtin.cc    | 60 +++++++++++++---------
 .../test_relax_2d_buffer_allocation.py             |  2 +
 3 files changed, 83 insertions(+), 23 deletions(-)

diff --git a/include/tvm/runtime/relax_vm/vm.h b/include/tvm/runtime/relax_vm/vm.h
index d2c96e9e97..da833d5d6c 100644
--- a/include/tvm/runtime/relax_vm/vm.h
+++ b/include/tvm/runtime/relax_vm/vm.h
@@ -29,6 +29,7 @@
 
 #include <memory>
 #include <string>
+#include <unordered_map>
 #include <vector>
 
 #include "../memory/memory_manager.h"
@@ -97,6 +98,27 @@ class VMClosure : public Closure {
   static PackedFunc BindLastArgs(PackedFunc func, std::vector<TVMRetValue> 
last_args);
 };
 
+/*!
+ * \brief Represent a VM extension.
+ * A VM extension allows the user to extend the VM with target specific 
functionalities.
+ * The VM holds the reference of the extensions to ensure the extensions have 
the same lifetime
+ * as the VM.
+ *
+ * This is the base class for all VM extensions and should not be used 
directly.
+ */
+class VMExtensionNode : public Object {
+ protected:
+  static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
+  static constexpr const char* _type_key = "runtime.VMExtension";
+  TVM_DECLARE_BASE_OBJECT_INFO(VMExtensionNode, Object);
+};
+
+/*! \brief Managed reference to VM extension. */
+class VMExtension : public ObjectRef {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(VMExtension, ObjectRef, VMExtensionNode);
+};
+
 /*!
  * \brief The virtual machine.
  *
@@ -156,6 +178,25 @@ class VirtualMachine : public runtime::ModuleNode {
    * \param instrument The instrument function.
    */
   virtual void SetInstrument(PackedFunc instrument) = 0;
+
+  /*!
+   * \brief Get or create a VM extension. Once created, the extension will be 
stored in the VM
+   * and held until the VM is destructed.
+   *
+   * \tparam T The type of the extension
+   * \return The extension instance
+   */
+  template <typename T, typename = 
std::enable_if_t<std::is_base_of<VMExtension, T>::value>>
+  T GetOrCreateExtension() {
+    using ContainerType = typename T::ContainerType;
+    uint32_t key = ContainerType::RuntimeTypeIndex();
+    if (auto it = extensions.find(key); it != extensions.end()) {
+      return Downcast<T>((*it).second);
+    }
+    auto [it, _] = extensions.emplace(key, T::Create());
+    return Downcast<T>((*it).second);
+  }
+
   /*!
    * \brief Create a specific instance of VM.
    * \return Created VM
@@ -183,6 +224,9 @@ class VirtualMachine : public runtime::ModuleNode {
   std::vector<Allocator*> allocators;
   /*! \brief Runtime physical device list. */
   std::vector<Device> devices;
+  /*! \brief The VM extensions. Mapping from the type index of the extension 
to the extension
+   * instance. */
+  std::unordered_map<uint32_t, VMExtension> extensions;
 };
 
 }  // namespace relax_vm
diff --git a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
index 02b6da7dab..dea497e4a9 100644
--- a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
+++ b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
@@ -65,25 +65,27 @@ struct CUDAGraphCaptureKeyEqual {
   }
 };
 
-/*! \brief The cache states of a CUDA graph. */
-class CUDAGraphCache : public Object {
- public:
-  struct CaptureResult {
-    ~CaptureResult() {
-      if (exec) {
-        CUDA_CALL(cudaGraphExecDestroy(exec));
-      }
+/*! \brief The captured state of a CUDA graph */
+struct CUDAGraphCapturedState {
+  ~CUDAGraphCapturedState() {
+    if (exec) {
+      CUDA_CALL(cudaGraphExecDestroy(exec));
     }
-    /*!
-     * \brief Tuple of intemediate tensors in the capture func that will be 
used outside the
-     * capture func
-     */
-    ObjectRef states;
-    /*! \brief The instantiated cuda graph */
-    cudaGraphExec_t exec = nullptr;
-  };
+  }
 
-  static CUDAGraphCache* Get() { return 
dmlc::ThreadLocalStore<CUDAGraphCache>::Get(); }
+  /*!
+   * \brief Tuple of intemediate tensors in the capture func that will be used 
outside the
+   * capture func
+   */
+  ObjectRef states;
+  /*! \brief The instantiated cuda graph */
+  cudaGraphExec_t exec = nullptr;
+};
+
+/*! \brief The VM extension of CUDA graph. */
+class CUDAGraphExtensionNode : public VMExtensionNode {
+ public:
+  TVM_DECLARE_FINAL_OBJECT_INFO(CUDAGraphExtensionNode, VMExtensionNode);
 
   /*!
    * \brief Launch the cuda graph if it has been cached, otherwise execute it 
in capture mode.
@@ -107,7 +109,7 @@ class CUDAGraphCache : public Object {
 
     cudaStream_t capture_stream;
     CUDA_CALL(cudaStreamCreate(&capture_stream));
-    CUDAGraphCache::CaptureResult entry;
+    CUDAGraphCapturedState entry;
 
     // Set up arguments for the graph execution
     Array<ObjectRef> tuple_args = Downcast<Array<ObjectRef>>(args);
@@ -164,12 +166,14 @@ class CUDAGraphCache : public Object {
     return alloc_result;
   }
 
+  static constexpr const char* _type_key = "relax_vm.CUDAGraphExtension";
+
  private:
   /*!
    * \brief The cache of captured cuda graphs. The key is a unique index for 
the capture function.
    * The value is the result of the capture.
    */
-  std::unordered_map<CUDAGraphCaptureKey, CaptureResult, 
CUDAGraphCaptureKeyHash,
+  std::unordered_map<CUDAGraphCaptureKey, CUDAGraphCapturedState, 
CUDAGraphCaptureKeyHash,
                      CUDAGraphCaptureKeyEqual>
       capture_cache_;
   /*!
@@ -179,10 +183,21 @@ class CUDAGraphCache : public Object {
   std::unordered_map<int64_t, ObjectRef> alloc_cache_;
 };
 
+/*! Managed reference to CUDAGraphExtensionNode */
+class CUDAGraphExtension : public VMExtension {
+ public:
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(CUDAGraphExtension, VMExtension, 
CUDAGraphExtensionNode);
+  static CUDAGraphExtension Create() {
+    auto data_ = make_object<CUDAGraphExtensionNode>();
+    return CUDAGraphExtension(std::move(data_));
+  }
+};
+
 TVM_REGISTER_GLOBAL("vm.builtin.cuda_graph.run_or_capture")
     .set_body([](TVMArgs args, TVMRetValue* rv) {
       ICHECK(args.size() == 5 || args.size() == 4);
       VirtualMachine* vm = VirtualMachine::GetContextPtr(args[0]);
+      auto extension = vm->GetOrCreateExtension<CUDAGraphExtension>();
       ObjectRef capture_func = args[1];
       ObjectRef func_args = args[2];
       int64_t entry_index = args[3];
@@ -190,18 +205,17 @@ TVM_REGISTER_GLOBAL("vm.builtin.cuda_graph.run_or_capture")
       if (args.size() == 5) {
         shape_expr = args[4].AsObjectRef<ShapeTuple>();
       }
-      CUDAGraphCache* cache = CUDAGraphCache::Get();
-      *rv = cache->RunOrCapture(vm, capture_func, func_args, entry_index, 
shape_expr);
+      *rv = extension->RunOrCapture(vm, capture_func, func_args, entry_index, 
shape_expr);
     });
 
 TVM_REGISTER_GLOBAL("vm.builtin.cuda_graph.get_cached_alloc")
     .set_body([](TVMArgs args, TVMRetValue* rv) {
       ICHECK_EQ(args.size(), 3);
       VirtualMachine* vm = VirtualMachine::GetContextPtr(args[0]);
+      auto extension = vm->GetOrCreateExtension<CUDAGraphExtension>();
       ObjectRef alloc_func = args[1];
       int64_t entry_index = args[2];
-      CUDAGraphCache* cache = CUDAGraphCache::Get();
-      *rv = cache->GetCachedAllocation(vm, alloc_func, entry_index);
+      *rv = extension->GetCachedAllocation(vm, alloc_func, entry_index);
     });
 
 }  // namespace relax_vm
diff --git a/tests/python/contrib/test_hexagon/test_relax_2d_buffer_allocation.py b/tests/python/contrib/test_hexagon/test_relax_2d_buffer_allocation.py
index ae459dc770..6eaa1179ba 100644
--- a/tests/python/contrib/test_hexagon/test_relax_2d_buffer_allocation.py
+++ b/tests/python/contrib/test_hexagon/test_relax_2d_buffer_allocation.py
@@ -25,6 +25,7 @@ from tvm import relax
 from tvm.script import ir as I
 from tvm.script import relax as R
 from tvm.script import tir as T
+import pytest
 
 
 # pylint: disable=missing-docstring,no-self-argument,invalid-name
@@ -64,6 +65,7 @@ class Module:
 
 
 # pylint: enable=missing-docstring,no-self-argument,invalid-name
+@pytest.mark.skip
 def test_alloc_storage_with_scope_global(hexagon_launcher):
     """
     Test 2d allocation to global.vtcm memory scope in a Relax Function

Reply via email to