This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new a84adaf0ff [CudaGraph] Handle exceptions thrown while capturing cuda graph (#17113)
a84adaf0ff is described below

commit a84adaf0ff39a40ab4cd0867972b805c4733ca10
Author: Eric Lunderberg <[email protected]>
AuthorDate: Thu Jun 27 16:07:22 2024 -0500

    [CudaGraph] Handle exceptions thrown while capturing cuda graph (#17113)
    
    * [CudaGraph] Handle exceptions thrown while capturing cuda graph
    
    Prior to this commit, an exception thrown during the capture of a cuda
    graph would result in `std::terminate` being called.  This commit
    updates the implementation of `"vm.builtin.cuda_graph.run_or_capture"`
    such that a thrown exception can be recovered from, and does not cause
    any changes to the state of TVM's cuda graph cache.
    
    - Call to `cudaStreamDestroy` was previously skipped, now moved to a
      RAII-style destructor in a `ScopedCUDAStream` class.
    
    - Call to `cudaStreamEndCapture` was previously skipped, end of cuda
      graph capture now performed as part of RAII-style destructor for
      `CUDACaptureStream` class.
    
    - Restoration of `CUDAThreadEntry::ThreadLocal()->stream` was
      previously skipped, now restored as part of RAII-style destructor
      for `CUDACaptureStream` class.
    
    - Previously, an error raised from `cudaGraphInstantiate` would leave
      the `capture_cache_` in an ill-formed state.  Now, the
      `capture_cache_` is only updated after a valid
      `CUDAGraphCapturedState` has been fully constructed.
    
    * lint fix
    
    * Unit test fix
---
 src/runtime/relax_vm/cuda/cuda_graph_builtin.cc | 81 ++++++++++++++++++++-----
 tests/python/relax/test_vm_cuda_graph.py        | 77 ++++++++++++++++++++++-
 2 files changed, 140 insertions(+), 18 deletions(-)

diff --git a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
index dea497e4a9..e8901c0f19 100644
--- a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
+++ b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
@@ -32,6 +32,8 @@ namespace tvm {
 namespace runtime {
 namespace relax_vm {
 
+namespace {
+
 struct CUDAGraphCaptureKey {
   // The unique index of the capture function within the module
   int64_t index;
@@ -67,6 +69,18 @@ struct CUDAGraphCaptureKeyEqual {
 
 /*! \brief The captured state of a CUDA graph */
 struct CUDAGraphCapturedState {
+  CUDAGraphCapturedState() {}
+
+  CUDAGraphCapturedState(const CUDAGraphCapturedState&) = delete;
+  CUDAGraphCapturedState(CUDAGraphCapturedState&& other) { *this = std::move(other); }
+
+  CUDAGraphCapturedState& operator=(const CUDAGraphCapturedState&) = delete;
+  CUDAGraphCapturedState& operator=(CUDAGraphCapturedState&& other) {
+    std::swap(states, other.states);
+    std::swap(exec, other.exec);
+    return *this;
+  }
+
   ~CUDAGraphCapturedState() {
     if (exec) {
       CUDA_CALL(cudaGraphExecDestroy(exec));
@@ -82,6 +96,43 @@ struct CUDAGraphCapturedState {
   cudaGraphExec_t exec = nullptr;
 };
 
+class ScopedCUDAStream {
+ public:
+  ScopedCUDAStream() { CUDA_CALL(cudaStreamCreate(&stream_)); }
+  ~ScopedCUDAStream() { cudaStreamDestroy(stream_); }
+  ScopedCUDAStream(const ScopedCUDAStream&) = delete;
+  ScopedCUDAStream(ScopedCUDAStream&&) = delete;
+  ScopedCUDAStream& operator=(const ScopedCUDAStream&) = delete;
+  ScopedCUDAStream& operator=(ScopedCUDAStream&&) = delete;
+
+  operator cudaStream_t() const { return stream_; }
+
+ private:
+  cudaStream_t stream_;
+};
+
+class CUDACaptureStream {
+ public:
+  explicit CUDACaptureStream(cudaGraph_t* graph)
+      : prev_default_stream_(CUDAThreadEntry::ThreadLocal()->stream), output_graph_(graph) {
+    CUDAThreadEntry::ThreadLocal()->stream = capture_stream_;
+
+    CUDA_CALL(cudaStreamBeginCapture(capture_stream_, cudaStreamCaptureModeGlobal));
+  }
+  ~CUDACaptureStream() {
+    cudaStreamEndCapture(capture_stream_, output_graph_);
+    CUDAThreadEntry::ThreadLocal()->stream = prev_default_stream_;
+  }
+
+ private:
+  cudaStream_t prev_default_stream_;
+  ScopedCUDAStream capture_stream_;
+
+  cudaGraph_t* output_graph_;
+};
+
+}  // namespace
+
 /*! \brief The VM extension of CUDA graph. */
 class CUDAGraphExtensionNode : public VMExtensionNode {
  public:
@@ -107,10 +158,6 @@ class CUDAGraphExtensionNode : public VMExtensionNode {
       return states;
     }
 
-    cudaStream_t capture_stream;
-    CUDA_CALL(cudaStreamCreate(&capture_stream));
-    CUDAGraphCapturedState entry;
-
     // Set up arguments for the graph execution
     Array<ObjectRef> tuple_args = Downcast<Array<ObjectRef>>(args);
     int nargs = static_cast<int>(tuple_args.size());
@@ -130,21 +177,23 @@ class CUDAGraphExtensionNode : public VMExtensionNode {
 
     // Run the graph in capture mode
     cudaGraph_t graph;
-    std::swap(capture_stream, CUDAThreadEntry::ThreadLocal()->stream);
-    CUDA_CALL(cudaStreamBeginCapture(CUDAThreadEntry::ThreadLocal()->stream,
-                                     cudaStreamCaptureModeGlobal));
 
-    vm->InvokeClosurePacked(capture_func, TVMArgs(values.data(), tcodes.data(), nargs),
-                            &capture_func_rv);
-    entry.states = capture_func_rv;
-    CUDA_CALL(cudaStreamEndCapture(CUDAThreadEntry::ThreadLocal()->stream, &graph));
-    std::swap(capture_stream, CUDAThreadEntry::ThreadLocal()->stream);
+    {
+      CUDACaptureStream capture_stream(&graph);
+      vm->InvokeClosurePacked(capture_func, TVMArgs(values.data(), tcodes.data(), nargs),
+                              &capture_func_rv);
+    }
 
-    capture_cache_[entry_key] = entry;
-    CUDA_CALL(cudaGraphInstantiate(&capture_cache_[entry_key].exec, graph, NULL, NULL, 0));
-    CUDA_CALL(cudaStreamDestroy(capture_stream));
+    CUDAGraphCapturedState entry;
+    entry.states = capture_func_rv;
+    CUDA_CALL(cudaGraphInstantiate(&entry.exec, graph, NULL, NULL, 0));
     CUDA_CALL(cudaGraphDestroy(graph));
-    return entry.states;
+
+    ObjectRef states = entry.states;
+
+    capture_cache_[entry_key] = std::move(entry);
+
+    return states;
   }
 
   /*!
diff --git a/tests/python/relax/test_vm_cuda_graph.py b/tests/python/relax/test_vm_cuda_graph.py
index 6a20b6b1f8..49ebcc1d05 100644
--- a/tests/python/relax/test_vm_cuda_graph.py
+++ b/tests/python/relax/test_vm_cuda_graph.py
@@ -16,10 +16,13 @@
 # under the License.
 
 import tvm
-from tvm.script import tir as T, relax as R, ir as I
-from tvm import relax
 import tvm.testing
+
+from tvm import relax
+from tvm.script import tir as T, relax as R, ir as I
+
 import numpy as np
+import pytest
 
 
 # fmt: off
@@ -104,5 +107,75 @@ def test_vm_run():
     tvm.testing.assert_allclose(y.asnumpy(), y_np, rtol=1e-5, atol=1e-5)
 
 
+@tvm.testing.requires_cudagraph
+def test_capture_error_is_recoverable():
+    """Function calls while capturing cudagraph may throw exceptions
+
+    Calls to PackedFuncs may occur within a captured cudaGraph.  If a
+    call to that PackedFunc raises an exception while capturing the
+    cudaGraph, throwing exception should cleanly unwind the stack, and
+    the exception may be caught in the calling scope.
+
+    This is a regression test.  In previous implementations, an
+    exception thrown while capturing a cudaGraph would skip the call
+    to `cudaStreamEndCapture`, causing additional exceptions to be
+    thrown while freeing memory in TVM destructors.  Since C++ does
+    not support stack unwinding from multiple simultaneous exceptions,
+    this would result in immediate `std::terminate`, making it
+    difficult to debug the original error.
+
+    """
+
+    target = tvm.target.Target("cuda")
+    dev = tvm.cuda()
+
+    @tvm.register_func("test_vm_cuda_graph.invalid_impl_for_cudagraph", override=True)
+    def invalid_impl_for_cudagraph(arg_tensor):
+        # Memory allocation/deallocation may not be performed while
+        # capturing a cudaGraph.  This passes the warm-up run
+        # performed by "vm.builtin.cuda_graph.run_or_capture", but
+        # throws an exception when the cudaGraph is being captured.
+        _dummy_workspace = tvm.nd.empty([16], "float16", dev)
+        return arg_tensor
+
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(A: R.Tensor([16], "float16")):
+            B = R.add(A, A)
+            C = R.call_pure_packed(
+                "test_vm_cuda_graph.invalid_impl_for_cudagraph",
+                B,
+                sinfo_args=R.Tensor([16], "float16"),
+            )
+            D = R.add(C, C)
+            return D
+
+    with target, tvm.ir.transform.PassContext(config={"relax.backend.use_cuda_graph": True}):
+        Module = tvm.ir.transform.Sequential(
+            [
+                tvm.relax.transform.LegalizeOps(),
+                tvm.tir.transform.DefaultGPUSchedule(),
+                tvm.relax.transform.RemovePurityChecking(),
+                tvm.relax.transform.CallTIRRewrite(),
+                tvm.relax.transform.StaticPlanBlockMemory(),
+                tvm.relax.transform.RewriteCUDAGraph(),
+            ]
+        )(Module)
+
+    assert "cuda_graph_alloc" in Module, (
+        "Validity of unit test requires the call to `invalid_impl_for_cudagraph` "
+        "to have been captured by RewriteCUDAGraph."
+    )
+
+    built = tvm.relax.build(Module, target=target)
+    vm = tvm.relax.VirtualMachine(built, dev)
+
+    arg = tvm.nd.array(np.arange(16).astype("float16"), dev)
+
+    with pytest.raises(tvm.TVMError):
+        vm["main"](arg)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to