This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new fc78b22fbc [Relax][VM] Refactor CUDA graph builtins as VM extension
(#16823)
fc78b22fbc is described below
commit fc78b22fbc469153f4d50de10891374e2c47f8bc
Author: Wuwei Lin <[email protected]>
AuthorDate: Mon Apr 1 15:23:54 2024 -0700
[Relax][VM] Refactor CUDA graph builtins as VM extension (#16823)
* [Relax][VM] Refactor CUDA graph builtins as VM extension
* skip test
---
include/tvm/runtime/relax_vm/vm.h | 44 ++++++++++++++++
src/runtime/relax_vm/cuda/cuda_graph_builtin.cc | 60 +++++++++++++---------
.../test_relax_2d_buffer_allocation.py | 2 +
3 files changed, 83 insertions(+), 23 deletions(-)
diff --git a/include/tvm/runtime/relax_vm/vm.h
b/include/tvm/runtime/relax_vm/vm.h
index d2c96e9e97..da833d5d6c 100644
--- a/include/tvm/runtime/relax_vm/vm.h
+++ b/include/tvm/runtime/relax_vm/vm.h
@@ -29,6 +29,7 @@
#include <memory>
#include <string>
+#include <unordered_map>
#include <vector>
#include "../memory/memory_manager.h"
@@ -97,6 +98,27 @@ class VMClosure : public Closure {
static PackedFunc BindLastArgs(PackedFunc func, std::vector<TVMRetValue>
last_args);
};
+/*!
+ * \brief Represent a VM extension.
+ * A VM extension allows the user to extend the VM with target specific
functionalities.
+ * The VM holds the reference of the extensions to ensure the extensions have
the same lifetime
+ * as the VM.
+ *
+ * This is the base class for all VM extensions and should not be used
directly.
+ */
+class VMExtensionNode : public Object {
+ protected:
+ static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
+ static constexpr const char* _type_key = "runtime.VMExtension";
+ TVM_DECLARE_BASE_OBJECT_INFO(VMExtensionNode, Object);
+};
+
+/*! \brief Managed reference to VM extension. */
+class VMExtension : public ObjectRef {
+ public:
+ TVM_DEFINE_OBJECT_REF_METHODS(VMExtension, ObjectRef, VMExtensionNode);
+};
+
/*!
* \brief The virtual machine.
*
@@ -156,6 +178,25 @@ class VirtualMachine : public runtime::ModuleNode {
* \param instrument The instrument function.
*/
virtual void SetInstrument(PackedFunc instrument) = 0;
+
+ /*!
+ * \brief Get or create a VM extension. Once created, the extension will be
stored in the VM
+ * and held until the VM is destructed.
+ *
+ * \tparam T The type of the extension
+ * \return The extension instance
+ */
+ template <typename T, typename =
std::enable_if_t<std::is_base_of<VMExtension, T>::value>>
+ T GetOrCreateExtension() {
+ using ContainerType = typename T::ContainerType;
+ uint32_t key = ContainerType::RuntimeTypeIndex();
+ if (auto it = extensions.find(key); it != extensions.end()) {
+ return Downcast<T>((*it).second);
+ }
+ auto [it, _] = extensions.emplace(key, T::Create());
+ return Downcast<T>((*it).second);
+ }
+
/*!
* \brief Create a specific instance of VM.
* \return Created VM
@@ -183,6 +224,9 @@ class VirtualMachine : public runtime::ModuleNode {
std::vector<Allocator*> allocators;
/*! \brief Runtime physical device list. */
std::vector<Device> devices;
+ /*! \brief The VM extensions. Mapping from the type index of the extension
to the extension
+ * instance. */
+ std::unordered_map<uint32_t, VMExtension> extensions;
};
} // namespace relax_vm
diff --git a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
index 02b6da7dab..dea497e4a9 100644
--- a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
+++ b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
@@ -65,25 +65,27 @@ struct CUDAGraphCaptureKeyEqual {
}
};
-/*! \brief The cache states of a CUDA graph. */
-class CUDAGraphCache : public Object {
- public:
- struct CaptureResult {
- ~CaptureResult() {
- if (exec) {
- CUDA_CALL(cudaGraphExecDestroy(exec));
- }
+/*! \brief The captured state of a CUDA graph */
+struct CUDAGraphCapturedState {
+ ~CUDAGraphCapturedState() {
+ if (exec) {
+ CUDA_CALL(cudaGraphExecDestroy(exec));
}
- /*!
- * \brief Tuple of intermediate tensors in the capture func that will be
used outside the
- * capture func
- */
- ObjectRef states;
- /*! \brief The instantiated cuda graph */
- cudaGraphExec_t exec = nullptr;
- };
+ }
- static CUDAGraphCache* Get() { return
dmlc::ThreadLocalStore<CUDAGraphCache>::Get(); }
+ /*!
+ * \brief Tuple of intermediate tensors in the capture func that will be used
outside the
+ * capture func
+ */
+ ObjectRef states;
+ /*! \brief The instantiated cuda graph */
+ cudaGraphExec_t exec = nullptr;
+};
+
+/*! \brief The VM extension of CUDA graph. */
+class CUDAGraphExtensionNode : public VMExtensionNode {
+ public:
+ TVM_DECLARE_FINAL_OBJECT_INFO(CUDAGraphExtensionNode, VMExtensionNode);
/*!
* \brief Launch the cuda graph if it has been cached, otherwise execute it
in capture mode.
@@ -107,7 +109,7 @@ class CUDAGraphCache : public Object {
cudaStream_t capture_stream;
CUDA_CALL(cudaStreamCreate(&capture_stream));
- CUDAGraphCache::CaptureResult entry;
+ CUDAGraphCapturedState entry;
// Set up arguments for the graph execution
Array<ObjectRef> tuple_args = Downcast<Array<ObjectRef>>(args);
@@ -164,12 +166,14 @@ class CUDAGraphCache : public Object {
return alloc_result;
}
+ static constexpr const char* _type_key = "relax_vm.CUDAGraphExtension";
+
private:
/*!
* \brief The cache of captured cuda graphs. The key is a unique index for
the capture function.
* The value is the result of the capture.
*/
- std::unordered_map<CUDAGraphCaptureKey, CaptureResult,
CUDAGraphCaptureKeyHash,
+ std::unordered_map<CUDAGraphCaptureKey, CUDAGraphCapturedState,
CUDAGraphCaptureKeyHash,
CUDAGraphCaptureKeyEqual>
capture_cache_;
/*!
@@ -179,10 +183,21 @@ class CUDAGraphCache : public Object {
std::unordered_map<int64_t, ObjectRef> alloc_cache_;
};
+/*! Managed reference to CUDAGraphExtensionNode */
+class CUDAGraphExtension : public VMExtension {
+ public:
+ TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(CUDAGraphExtension, VMExtension,
CUDAGraphExtensionNode);
+ static CUDAGraphExtension Create() {
+ auto data_ = make_object<CUDAGraphExtensionNode>();
+ return CUDAGraphExtension(std::move(data_));
+ }
+};
+
TVM_REGISTER_GLOBAL("vm.builtin.cuda_graph.run_or_capture")
.set_body([](TVMArgs args, TVMRetValue* rv) {
ICHECK(args.size() == 5 || args.size() == 4);
VirtualMachine* vm = VirtualMachine::GetContextPtr(args[0]);
+ auto extension = vm->GetOrCreateExtension<CUDAGraphExtension>();
ObjectRef capture_func = args[1];
ObjectRef func_args = args[2];
int64_t entry_index = args[3];
@@ -190,18 +205,17 @@
TVM_REGISTER_GLOBAL("vm.builtin.cuda_graph.run_or_capture")
if (args.size() == 5) {
shape_expr = args[4].AsObjectRef<ShapeTuple>();
}
- CUDAGraphCache* cache = CUDAGraphCache::Get();
- *rv = cache->RunOrCapture(vm, capture_func, func_args, entry_index,
shape_expr);
+ *rv = extension->RunOrCapture(vm, capture_func, func_args, entry_index,
shape_expr);
});
TVM_REGISTER_GLOBAL("vm.builtin.cuda_graph.get_cached_alloc")
.set_body([](TVMArgs args, TVMRetValue* rv) {
ICHECK_EQ(args.size(), 3);
VirtualMachine* vm = VirtualMachine::GetContextPtr(args[0]);
+ auto extension = vm->GetOrCreateExtension<CUDAGraphExtension>();
ObjectRef alloc_func = args[1];
int64_t entry_index = args[2];
- CUDAGraphCache* cache = CUDAGraphCache::Get();
- *rv = cache->GetCachedAllocation(vm, alloc_func, entry_index);
+ *rv = extension->GetCachedAllocation(vm, alloc_func, entry_index);
});
} // namespace relax_vm
diff --git
a/tests/python/contrib/test_hexagon/test_relax_2d_buffer_allocation.py
b/tests/python/contrib/test_hexagon/test_relax_2d_buffer_allocation.py
index ae459dc770..6eaa1179ba 100644
--- a/tests/python/contrib/test_hexagon/test_relax_2d_buffer_allocation.py
+++ b/tests/python/contrib/test_hexagon/test_relax_2d_buffer_allocation.py
@@ -25,6 +25,7 @@ from tvm import relax
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T
+import pytest
# pylint: disable=missing-docstring,no-self-argument,invalid-name
@@ -64,6 +65,7 @@ class Module:
# pylint: enable=missing-docstring,no-self-argument,invalid-name
[email protected]
def test_alloc_storage_with_scope_global(hexagon_launcher):
"""
Test 2d allocation to global.vtcm memory scope in a Relax Function