This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new f83a32906f [Relax] Share storage allocs among functions after cuda graph rewriting (#16830)
f83a32906f is described below

commit f83a32906f9d3765946db0b9bdc31e4eef5072b3
Author: Wuwei Lin <[email protected]>
AuthorDate: Mon Apr 1 21:09:55 2024 -0700

    [Relax] Share storage allocs among functions after cuda graph rewriting (#16830)
---
 src/relax/transform/rewrite_cuda_graph.cc          | 386 ++++++++++++++++-----
 .../relax/test_transform_rewrite_cuda_graph.py     | 241 ++++++++++++-
 2 files changed, 518 insertions(+), 109 deletions(-)
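
Before the diff, a note on how this pass is exercised. A minimal sketch (assuming a TVM build with Relax enabled; the module below and the "dummy" packed function are illustrative, adapted from the updated tests):

    import tvm
    from tvm import relax
    from tvm.script import ir as I, relax as R

    @I.ir_module
    class Before:
        @R.function
        def main(x: R.Tensor((8,), "float32")) -> R.Tuple:
            R.func_attr({"relax.force_pure": True})
            storage = R.memory.alloc_storage(R.shape([8]), 0, "global", "float32")
            alloc = R.memory.alloc_tensor(storage, 0, R.shape([8]), "float32")
            _ = R.call_packed("dummy", x, alloc, sinfo_args=(R.Tuple,))
            storage1 = R.memory.alloc_storage(R.shape([8]), 0, "global", "float32")
            alloc1 = R.memory.alloc_tensor(storage1, 0, R.shape([8]), "float32")
            _1 = R.call_packed("dummy", alloc, alloc1, sinfo_args=(R.Tuple,))
            return R.tuple()

    # After this commit, the static storage allocations of all rewritten
    # functions are merged into a single shared "cuda_graph_alloc" function,
    # and each capture function is named "<original function>_cuda_graph_capture".
    after = relax.transform.RewriteCUDAGraph()(Before)
    after.show()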

diff --git a/src/relax/transform/rewrite_cuda_graph.cc b/src/relax/transform/rewrite_cuda_graph.cc
index 25b229ebce..d0e20ffd76 100644
--- a/src/relax/transform/rewrite_cuda_graph.cc
+++ b/src/relax/transform/rewrite_cuda_graph.cc
@@ -49,17 +49,19 @@
  * 2. Lift the regions identified in step 1 to a separate function and rewrite the original function
  * with `CUDAGraphRewriter`.
  */
-
+#include <tvm/relax/analysis.h>
 #include <tvm/relax/backend.h>
 #include <tvm/relax/expr_functor.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/expr_functor.h>
 #include <tvm/tir/stmt_functor.h>
 
+#include <unordered_map>
+#include <vector>
+
 #include "../../support/arena.h"
 #include "../../support/ordered_set.h"
 #include "../../support/utils.h"
-
 namespace tvm {
 namespace relax {
 
@@ -79,9 +81,10 @@ struct LiftedFunctionRewritePlan {
   // Variable remappings between the original function and the lifted function
 
   // The bindings in the original function that are lifted
-  std::unordered_set<const VarNode*> lifted_bindings;
+  std::vector<const VarBindingNode*> lifted_bindings;
   // The corresponding binding vars in the original function of the outputs of the lifted function
-  std::vector<const VarNode*> outputs;
+  // to the index of the element in the output tuple of the lifted function.
+  std::unordered_map<const VarNode*, int> outputs;
   // The corresponding binding vars in the original function of the inputs of the lifted function
   std::vector<const VarNode*> inputs;
   // The tir vars in the original function that are propagated to the lifted function
@@ -170,13 +173,68 @@ class FuncBuilder : public ExprMutator {
   Map<tir::Var, PrimExpr> tir_var_remap_;
 };
 
+// Collect the storage objects that are used as the function output
+class OutputStorageCollector : public ExprVisitor {
+ public:
+  static std::unordered_set<const VarNode*> Collect(const Function& func) {
+    OutputStorageCollector collector;
+    collector.VisitExpr(func);
+    return std::move(collector.output_storages_);
+  }
+
+ private:
+  void VisitExpr_(const SeqExprNode* seq_expr) final {
+    auto output_vars = FreeVars(seq_expr->body);
+    for (const auto& var : output_vars) {
+      output_vars_.insert(var.get());
+    }
+    // Visit the blocks in reverse order for backward propagation
+    for (auto it = seq_expr->blocks.rbegin(); it != seq_expr->blocks.rend(); ++it) {
+      VisitBindingBlock(*it);
+    }
+  }
+
+  void VisitBinding_(const VarBindingNode* binding, const CallNode* call) final {
+    static const auto& mem_alloc_tensor_op = Op::Get("relax.memory.alloc_tensor");
+    if (output_vars_.count(binding->var.get()) && call->op.same_as(mem_alloc_tensor_op)) {
+      output_storages_.insert(call->args[0].as<VarNode>());
+    }
+  }
+
+  void VisitBindingBlock_(const BindingBlockNode* binding_block) override {
+    // Visit the bindings in reverse order
+    for (auto it = binding_block->bindings.rbegin(); it != binding_block->bindings.rend(); ++it) {
+      VisitBinding(*it);
+    }
+  }
+
+  void VisitBinding_(const VarBindingNode* binding, const VarNode* var) final {
+    if (output_vars_.count(binding->var.get())) {
+      output_vars_.insert(var);
+    }
+  }
+
+  void VisitBinding_(const VarBindingNode* binding, const TupleNode* tuple) final {
+    if (output_vars_.count(binding->var.get())) {
+      for (const auto& field : tuple->fields) {
+        output_vars_.insert(field.as<VarNode>());
+      }
+    }
+  }
+
+  std::unordered_set<const VarNode*> output_storages_;
+  std::unordered_set<const VarNode*> output_vars_;
+};
+
 /*!
  * \brief The planner for rewriting the function to enable cuda graph capturing.
  */
 class CUDAGraphRewritePlanner : public ExprVisitor {
  public:
-  explicit CUDAGraphRewritePlanner(const IRModule& mod) : mod_(mod) {}
-  std::vector<LiftedFunctionRewritePlan> Plan() {
+  explicit CUDAGraphRewritePlanner(const IRModule& mod, support::Arena* arena)
+      : mod_(mod), arena_(arena) {}
+  std::pair<std::vector<LiftedFunctionRewritePlan*>, std::vector<LiftedFunctionRewritePlan*>>
+  Plan() {
     for (const auto& pair : mod_->functions) {
       if (pair.second->IsInstance<FunctionNode>()) {
         // If a function has the num_input attribute, the last func->params.size() - num_inputs
@@ -188,41 +246,41 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
           }
         }
         CollectSymbolicVarHints(func);
+        disabled_storage_vars_ = OutputStorageCollector::Collect(func);
         VisitExpr(func);
       }
     }
-    std::vector<LiftedFunctionRewritePlan> plans;
-
-    auto region_to_plan = [&](FuncBuilder* region, bool is_alloc) -> LiftedFunctionRewritePlan {
-      LiftedFunctionRewritePlan plan;
-      plan.is_alloc = true;
-      plan.func = region->Build();
+    auto region_to_plan = [&](FuncBuilder* region, bool is_alloc) -> LiftedFunctionRewritePlan* {
+      auto* plan = arena_->make<LiftedFunctionRewritePlan>();
+      plan->is_alloc = true;
+      plan->func = region->Build();
       ICHECK(region->size());
-      plan.launch_point = region->bindings_.front()->var.get();
-      plan.is_alloc = is_alloc;
-      for (const auto* binding : region->bindings_) {
-        plan.lifted_bindings.insert(binding->var.get());
-      }
+      plan->launch_point = region->bindings_.front()->var.get();
+      plan->is_alloc = is_alloc;
+      plan->lifted_bindings = std::move(region->bindings_);
       if (region->shape_expr_inputs_.size()) {
         Array<PrimExpr> tir_vars;
         for (const auto* var : region->shape_expr_inputs_) {
           tir_vars.push_back(GetRef<PrimExpr>(var));
         }
-        plan.propogated_tir_vars = ShapeExpr(tir_vars);
+        plan->propogated_tir_vars = ShapeExpr(tir_vars);
+      }
+      plan->inputs.assign(region->inputs_.begin(), region->inputs_.end());
+      for (const auto* var : region->outputs_) {
+        plan->outputs[var] = plan->outputs.size();
       }
-      plan.inputs.assign(region->inputs_.begin(), region->inputs_.end());
-      plan.outputs.assign(region->outputs_.begin(), region->outputs_.end());
       return plan;
     };
 
-    for (auto* region : alloc_storages_) {
-      plans.push_back(region_to_plan(region, /*is_alloc=*/true));
-    }
-
-    for (auto* region : captured_regions_) {
-      plans.push_back(region_to_plan(region, /*is_alloc=*/false));
-    }
-    return plans;
+    std::vector<LiftedFunctionRewritePlan*> alloc_plans, capture_plans;
+    alloc_plans.reserve(alloc_storages_.size());
+    capture_plans.reserve(captured_regions_.size());
+    std::transform(alloc_storages_.begin(), alloc_storages_.end(), std::back_inserter(alloc_plans),
+                   [&](FuncBuilder* region) { return region_to_plan(region, /*is_alloc=*/true); });
+    std::transform(captured_regions_.begin(), captured_regions_.end(),
+                   std::back_inserter(capture_plans),
+                   [&](FuncBuilder* region) { return region_to_plan(region, /*is_alloc=*/false); });
+    return {std::move(alloc_plans), std::move(capture_plans)};
   }
 
   /*!
@@ -241,31 +299,36 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
   *\brief Start a new static region. This method should be called when encountering a
   * CUDA kernel launch (calls to PrimFunc or ExternFunc) that only depends on static parameters.
    */
-  void StartRegion() { current_.capture_builder = arena_.make<FuncBuilder>(); }
+  void StartRegion() { current_block_scope_.capture_builder = arena_->make<FuncBuilder>(); }
 
   /*!
   * \brief Finish a static region. This method should be called when non-static bindings or
    * unsupported operations are encountered.
    */
   void EndRegion() {
-    if (current_.capture_builder && current_.capture_builder->size()) {
-      captured_regions_.emplace_back(current_.capture_builder);
+    if (current_block_scope_.capture_builder && current_block_scope_.capture_builder->size()) {
+      captured_regions_.emplace_back(current_block_scope_.capture_builder);
     }
-    current_.capture_builder = nullptr;
+    current_block_scope_.capture_builder = nullptr;
+  }
+
+  void VisitExpr_(const FunctionNode* func) final {
+    current_function_scope_.alloc_storage_builder = arena_->make<FuncBuilder>();
+    ExprVisitor::VisitExpr_(func);
+    if (current_function_scope_.alloc_storage_builder->outputs_.size()) {
+      alloc_storages_.emplace_back(current_function_scope_.alloc_storage_builder);
+    }
+    current_function_scope_.alloc_storage_builder = nullptr;
   }
 
   void VisitBindingBlock_(const BindingBlockNode* binding_block) final {
-    Scope new_scope;
-    std::swap(new_scope, current_);
-    current_.alloc_storage_builder = arena_.make<FuncBuilder>();
+    BindingBlockScope new_scope;
+    std::swap(new_scope, current_block_scope_);
     for (const auto& binding : binding_block->bindings) {
       VisitBinding(binding);
     }
     EndRegion();
-    if (current_.alloc_storage_builder->outputs_.size()) {
-      alloc_storages_.emplace_back(current_.alloc_storage_builder);
-    }
-    std::swap(new_scope, current_);
+    std::swap(new_scope, current_block_scope_);
   }
 
   void VisitBinding_(const VarBindingNode* binding, const CallNode* call) final {
@@ -273,8 +336,10 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
     static const auto& builtin_alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor");
     static const auto& call_builtin_with_ctx_op = Op::Get("relax.call_builtin_with_ctx");
 
-    if (call->op.same_as(mem_alloc_storage_op) && IsStaticAllocStorage(binding)) {
-      AddStaticBinding(binding, /*is_alloc_storage=*/true);
+    if (call->op.same_as(mem_alloc_storage_op)) {
+      if (IsStaticAllocStorage(binding)) {
+        AddStaticBinding(binding, /*is_alloc_storage=*/true);
+      }
       return;
     } else if (call->op.same_as(builtin_alloc_tensor_op)) {
       return;
@@ -321,7 +386,7 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
         }
         return false;
       }();
-      if (current_.capture_builder == nullptr && is_kernel_launch) {
+      if (current_block_scope_.capture_builder == nullptr && is_kernel_launch) {
         StartRegion();
       }
       AddStaticBinding(binding, /*is_alloc_storage=*/false);
@@ -335,24 +400,24 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
 
   void MarkAsFuncInput(const std::vector<const VarNode*>& vars,
                        const std::vector<const tir::VarNode*>& tir_vars = {}) {
-    if (current_.capture_builder == nullptr) {
+    if (current_block_scope_.capture_builder == nullptr) {
       return;
     }
     for (const VarNode* var : vars) {
       auto it = binding_to_region_.find(var);
-      if (it == binding_to_region_.end() || it->second != current_.capture_builder) {
-        current_.capture_builder->MarkInput(var);
+      if (it == binding_to_region_.end() || it->second != current_block_scope_.capture_builder) {
+        current_block_scope_.capture_builder->MarkInput(var);
       }
     }
     for (const tir::VarNode* tir_var : tir_vars) {
-      current_.capture_builder->MarkShapeExprInput(tir_var);
+      current_block_scope_.capture_builder->MarkShapeExprInput(tir_var);
     }
   }
 
   void MarkAsFuncOutput(const std::vector<const VarNode*>& vars) {
     for (const VarNode* var : vars) {
       if (auto it = binding_to_region_.find(var);
-          it != binding_to_region_.end() && it->second != current_.capture_builder) {
+          it != binding_to_region_.end() && it->second != current_block_scope_.capture_builder) {
         it->second->MarkOutput(var);
       }
     }
@@ -476,6 +541,9 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
 
  private:
   bool IsStaticAllocStorage(const VarBindingNode* binding) {
+    if (disabled_storage_vars_.count(binding->var.get())) {
+      return false;
+    }
     // Check if the allocation has constant shape
     const auto* alloc_storage_call = binding->value.as<CallNode>();
     auto shape = Downcast<ShapeExpr>(alloc_storage_call->args[0]);
@@ -491,33 +559,41 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
    */
   void AddStaticBinding(const VarBindingNode* binding, bool is_alloc_storage) {
     if (is_alloc_storage) {
-      current_.alloc_storage_builder->AddBinding(binding);
-      binding_to_region_[binding->var.get()] = current_.alloc_storage_builder;
-    } else if (current_.capture_builder != nullptr) {
+      current_function_scope_.alloc_storage_builder->AddBinding(binding);
+      binding_to_region_[binding->var.get()] = current_function_scope_.alloc_storage_builder;
+    } else if (current_block_scope_.capture_builder != nullptr) {
       // Add the binding if the capture builder exists. It is possible that capture builder is
       // null when it is not capturing. This is the case that there are not yet any kernel launches
       // encountered, in this case static bindings (e.g. binding of other non-kernel-launch
       // operations) are marked but are not lifted.
-      current_.capture_builder->AddBinding(binding);
-      binding_to_region_[binding->var.get()] = current_.capture_builder;
+      current_block_scope_.capture_builder->AddBinding(binding);
+      binding_to_region_[binding->var.get()] = current_block_scope_.capture_builder;
     }
     static_vars_.emplace(binding->var.get());
   }
 
-  /*! \brief The states of the current scope (the BindingBlock) which is a pair of FuncBuilder.
+  /*! \brief The states of the current scope (the BindingBlock) which is a FuncBuilder.
   * The FuncBuilder are initialized with nullptr, meaning the planner is currently not doing any
   * lifting. They are initialized lazily when a binding that can be lifted is encountered.
    * They are reset to nullptr when an unsupported operation is encountered.
    */
-  struct Scope {
+  struct BindingBlockScope {
+    FuncBuilder* capture_builder = nullptr;  // The builder for the capture function
+  };
+
+  /*! \brief The states of the current function scope which is a FuncBuilder to build the storage
+   * allocation function.
+   */
+  struct FunctionScope {
     FuncBuilder* alloc_storage_builder = nullptr;  // The builder for the allocation function
-    FuncBuilder* capture_builder = nullptr;        // The builder for the capture function
   };
 
   // The IRModule
   IRModule mod_;
-  // States of the current scope
-  Scope current_;
+  // States of the current block scope
+  BindingBlockScope current_block_scope_;
+  // States of the current function scope
+  FunctionScope current_function_scope_;
   // Variables whose buffer address is fixed
   std::unordered_set<const VarNode*> static_vars_;
   // The name of the variables that are allowed to be symbolic
@@ -529,64 +605,183 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
   std::vector<FuncBuilder*> captured_regions_;
   // The regions for allocation.
   std::vector<FuncBuilder*> alloc_storages_;
+  // The binding variables that are not allowed to be captured.
+  std::unordered_set<const VarNode*> disabled_storage_vars_;
   // The arena.
-  support::Arena arena_;
+  support::Arena* arena_;
 };
 
+/*!
+ * \brief Merge storage allocations from different functions by reusing the largest allocation that
+ * can be shared among all the functions. The original rewriting plans are updated in-place to use
+ * the merged storage allocations.
+ *
+ * When multiple functions are rewritten to be executed with CUDA graph, the storage allocations
+ * from different functions can be reused. This function merges multiple storage allocation
+ * functions into a single function that allocates sufficiently large storage to be shared among
+ * all the functions.
+ *
+ * \param alloc_plans The allocation plans of the functions to be merged.
+ * \return The new allocation function that merges the storage allocations.
+ */
+Function MergeAllocationPlans(const std::vector<LiftedFunctionRewritePlan*>& alloc_plans) {
+  // The storage record that contains the size of the storage allocation and the binding of the
+  // storage allocation.
+  struct StorageRecord {
+    // The size of the storage object in bytes
+    int64_t size;
+    // The binding of the storage allocation
+    const VarBindingNode* binding;
+    // The source rewriting plan that the storage record is from
+    LiftedFunctionRewritePlan* src;
+
+    bool operator<(const StorageRecord& other) const { return size < other.size; }
+  };
+  // Using an (ordered) map to make sure the result is deterministic
+  std::map<String, std::vector<std::vector<StorageRecord>>> storage_records;
+  static const auto& mem_alloc_storage_op = Op::Get("relax.memory.alloc_storage");
+
+  // Collect the storage records for each storage scope. Storage records are stored separately
+  // for each original function.
+  for (int plan_id = 0; plan_id < static_cast<int>(alloc_plans.size()); ++plan_id) {
+    LiftedFunctionRewritePlan* plan = alloc_plans[plan_id];
+    ICHECK(plan->is_alloc);
+    for (const VarBindingNode* binding : plan->lifted_bindings) {
+      // Extract the storage record from the Call expr.
+      Call alloc_storage = Downcast<Call>(binding->value);
+      ICHECK(alloc_storage->op.same_as(mem_alloc_storage_op));
+      auto storage_shape = Downcast<ShapeExpr>(alloc_storage->args[0]);
+      ICHECK_EQ(storage_shape->values.size(), 1);
+      int64_t size = Downcast<IntImm>(storage_shape->values[0])->value;
+      int64_t virtual_device_id =
+          Downcast<IntImm>(Downcast<PrimValue>(alloc_storage->args[1])->value)->value;
+      ICHECK_EQ(virtual_device_id, 0);
+      String storage_scope = Downcast<StringImm>(alloc_storage->args[2])->value;
+      auto [it, _] = storage_records.try_emplace(storage_scope, alloc_plans.size());
+      it->second[plan_id].emplace_back(StorageRecord{size, binding, plan});
+    }
+  }
+
+  // Merge the storage records within each storage scope.
+  // This is achieved by sorting the storage records in descending order of size and then merging
+  // storage allocations from different functions to the largest allocation that can be shared
+  // among all the functions.
+  // This assumes that multiple functions will not run concurrently.
+  std::vector<const VarBindingNode*> merged_allocs;
+  // Merge the storage records within each storage scope.
+  for (auto& [storage_scope, curr_scope_records] : storage_records) {
+    // The number of storages needed for the current storage scope, which is the maximum number of
+    // storage records among all the functions.
+    int num_storages = 0;
+    for (auto& records_of_plan : curr_scope_records) {
+      // Sort descending by size, preserve the original order if the sizes are equal.
+      std::stable_sort(records_of_plan.rbegin(), records_of_plan.rend());
+      num_storages = std::max(num_storages, static_cast<int>(records_of_plan.size()));
+    }
+    // The iterators to scan the storage records of all functions from the left to the right
+    // at the same time.
+    std::vector<int> iters(alloc_plans.size(), 0);
+    for (int i = 0; i < num_storages; i++) {
+      // The storage records from different functions that can be merged to the same storage.
+      std::vector<StorageRecord> to_merge;
+      for (int plan_index = 0; plan_index < static_cast<int>(curr_scope_records.size());
+           plan_index++) {
+        if (iters[plan_index] < static_cast<int>(curr_scope_records[plan_index].size())) {
+          to_merge.push_back(curr_scope_records[plan_index][iters[plan_index]++]);
+        }
+      }
+      const StorageRecord& largest_storage =
+          *std::max_element(to_merge.begin(), to_merge.end(),
+                            [](const auto& lhs, const auto& rhs) { return lhs < rhs; });
+      // Merge the records to the largest allocation by updating the index of the output element
+      // to that of the new allocation function.
+      int storage_index = static_cast<int>(merged_allocs.size());
+      for (const StorageRecord& rec : to_merge) {
+        auto* plan = rec.src;
+        plan->outputs.at(rec.binding->var.get()) = storage_index;
+      }
+      merged_allocs.push_back(largest_storage.binding);
+    }
+  }
+  // Create the new allocation function for the merged allocations.
+  FuncBuilder builder;
+  for (const auto* binding : merged_allocs) {
+    builder.AddBinding(binding);
+    builder.MarkOutput(binding->var.get());
+  }
+  return builder.Build();
+}
+
 /*! \brief The rewriter for CUDA graph */
 class CUDAGraphRewriter : public ExprMutator {
  public:
   explicit CUDAGraphRewriter(const IRModule& mod) : ExprMutator(mod) {}
 
   IRModule Rewrite() {
-    CUDAGraphRewritePlanner planner(builder_->GetContextIRModule());
-    auto plans = planner.Plan();
-    for (const auto& plan : plans) {
-      subgraph_launches_[plan.launch_point] = plan;
-    }
+    CUDAGraphRewritePlanner planner(builder_->GetContextIRModule(), &arena_);
 
+    // Collect the target functions for rewriting before any mutation.
+    std::vector<std::pair<GlobalVar, Function>> target_functions;
     for (const auto& [gv, func] : builder_->GetContextIRModule()->functions) {
       if (func->IsInstance<FunctionNode>()) {
-        auto new_func = Downcast<Function>(VisitExpr(func));
-        if (!new_func.same_as(func)) {
-          builder_->UpdateFunction(gv, new_func);
-        }
+        target_functions.emplace_back(gv, Downcast<Function>(func));
+      }
+    }
+
+    auto [alloc_plans, capture_plans] = planner.Plan();
+    if (alloc_plans.size()) {
+      auto global_alloc_func = MergeAllocationPlans(alloc_plans);
+      gv_global_alloc_ = builder_->AddFunction(global_alloc_func, "cuda_graph_alloc");
+    }
+    for (const auto* plan : alloc_plans) {
+      subgraph_launches_[plan->launch_point] = plan;
+    }
+    for (const auto* plan : capture_plans) {
+      subgraph_launches_[plan->launch_point] = plan;
+    }
+
+    for (const auto& [gv, func] : target_functions) {
+      current_func_ = gv;
+      auto new_func = Downcast<Function>(VisitExpr(func));
+      if (!new_func.same_as(func)) {
+        builder_->UpdateFunction(gv, new_func);
       }
     }
     return builder_->GetContextIRModule();
   }
 
-  void LaunchSubgraph(const VarBindingNode* op, const LiftedFunctionRewritePlan& plan) {
+  void LaunchSubgraph(const VarBindingNode* op, const LiftedFunctionRewritePlan* plan) {
     static const auto& call_builtin_with_ctx_op = Op::Get("relax.call_builtin_with_ctx");
     static const auto& builtin_run_or_capture = ExternFunc("vm.builtin.cuda_graph.run_or_capture");
     static const auto& builtin_get_cached_alloc =
         ExternFunc("vm.builtin.cuda_graph.get_cached_alloc");
 
     Expr launch_subgraph;
-    auto gv_func =
-        builder_->AddFunction(plan.func, plan.is_alloc ? "cuda_graph_alloc" : "cuda_graph_capture");
-    if (plan.is_alloc) {
+    if (plan->is_alloc) {
       // Storage allocation should be fully static and shouldn't depend on any symbolic variables.
-      ICHECK(!plan.propogated_tir_vars.defined());
-      ICHECK(plan.inputs.empty());
-      launch_subgraph =
-          Call(call_builtin_with_ctx_op,
-               {builtin_get_cached_alloc,
-                Tuple({gv_func, PrimValue(IntImm(DataType::Int(64), index_alloc_++))})},
-               Attrs(), {plan.func->ret_struct_info});
+      ICHECK(!plan->propogated_tir_vars.defined());
+      ICHECK(plan->inputs.empty());
+      auto gv_alloc = gv_global_alloc_.value();
+      auto ret_struct_info = Downcast<FuncStructInfo>(gv_alloc->struct_info_.value())->ret;
+      launch_subgraph = Call(
+          call_builtin_with_ctx_op,
+          {builtin_get_cached_alloc, Tuple({gv_alloc, PrimValue(IntImm(DataType::Int(64), 0))})},
+          Attrs(), {ret_struct_info});
     } else {
-      StructInfo call_sinfo = plan.func->ret_struct_info;
+      auto gv_func = builder_->AddFunction(
+          plan->func, current_func_.value()->name_hint + "_cuda_graph_capture");
+      StructInfo call_sinfo = plan->func->ret_struct_info;
       // Arguments of the lifted function
       Array<Expr> args;
-      for (const auto& arg : plan.inputs) {
+      for (const auto& arg : plan->inputs) {
         args.push_back(VisitExpr_(arg));
       }
-      if (plan.propogated_tir_vars.defined()) {
-        ShapeExpr propogated_tir_vars = plan.propogated_tir_vars.value();
+      if (plan->propogated_tir_vars.defined()) {
+        ShapeExpr propogated_tir_vars = plan->propogated_tir_vars.value();
         args.push_back(propogated_tir_vars);
         // The ret_struct_info of the lifted function can contain symbolic variables. We need to
         // bind the symbolic parameters to the actual values.
-        const auto& shape_expr = plan.func->params.back();
+        const auto& shape_expr = plan->func->params.back();
         auto symbolic_params =
             Downcast<ShapeStructInfo>(shape_expr->struct_info_.value())->values.value();
         Map<tir::Var, PrimExpr> tir_var_remap;
@@ -599,25 +794,23 @@ class CUDAGraphRewriter : public ExprMutator {
       // Arguments of builtin_run_or_capture
       Array<Expr> tuple_arg_fields{gv_func, Tuple(args),
                                    PrimValue(IntImm(DataType::Int(64), index_capture_++))};
-      if (plan.propogated_tir_vars.defined()) {
+      if (plan->propogated_tir_vars.defined()) {
         // The shape expr is explicitly passed twice, one as the last argument of the lifted
         // function, one as the last argument of builtin_run_or_capture as the cache key. Explicitly
         // passing it twice simplifies the handling during the capture phase.
-        tuple_arg_fields.push_back(plan.propogated_tir_vars.value());
+        tuple_arg_fields.push_back(plan->propogated_tir_vars.value());
       }
       launch_subgraph =
           Call(call_builtin_with_ctx_op, {builtin_run_or_capture, Tuple(tuple_arg_fields)}, Attrs(),
                {call_sinfo});
     }
     Expr ret_value = builder_->Emit(launch_subgraph);
-    for (int i = 0; i < static_cast<int>(plan.outputs.size()); ++i) {
-      // The unpacked result is saved in the var_redef_. It will be emitted when 1) the var
-      // definition in the original IR is visited, or 2) the var is used as an input to another
-      // lifted function, whichever comes first.
-      var_redef_[plan.outputs[i]] = TupleGetItem(ret_value, i);
+    for (const auto& [var, tuple_index] : plan->outputs) {
+      var_redef_[var] = TupleGetItem(ret_value, tuple_index);
     }
-
-    lifted_bindings_.insert(plan.lifted_bindings.begin(), plan.lifted_bindings.end());
+    std::transform(plan->lifted_bindings.begin(), plan->lifted_bindings.end(),
+                   std::inserter(lifted_binding_vars_, lifted_binding_vars_.end()),
+                   [](const BindingNode* binding) { return binding->var.get(); });
   }
 
   void VisitBinding_(const VarBindingNode* op) final {
@@ -629,7 +822,7 @@ class CUDAGraphRewriter : public ExprMutator {
       EmitRedef(op->var.get(), it->second);
       return;
     }
-    if (lifted_bindings_.count(op->var.get())) {
+    if (lifted_binding_vars_.count(op->var.get())) {
       // The binding is lifted to the subgraph and will be removed from the original function.
       return;
     }
@@ -654,11 +847,14 @@ class CUDAGraphRewriter : public ExprMutator {
     return new_var;
   }
 
-  std::unordered_map<const VarNode*, LiftedFunctionRewritePlan> subgraph_launches_;
+  std::unordered_map<const VarNode*, const LiftedFunctionRewritePlan*> subgraph_launches_;
   std::unordered_map<const VarNode*, Expr> var_redef_;
-  std::unordered_set<const VarNode*> lifted_bindings_;
+  std::unordered_set<const VarNode*> lifted_binding_vars_;
   int index_alloc_ = 0;
   int index_capture_ = 0;
+  support::Arena arena_;
+  Optional<GlobalVar> gv_global_alloc_ = NullOpt;
+  Optional<GlobalVar> current_func_ = NullOpt;
 };
 
 IRModule RewriteCUDAGraph(IRModule mod) {
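
The merging strategy implemented by MergeAllocationPlans above is easiest to see on the sizes used by TestMergeAllocFuncs in the test diff below: within each storage scope, every function's allocations are sorted by size in descending order, and the i-th largest allocations across functions share a single storage sized to their maximum. A standalone sketch of that size computation (plain Python, independent of the C++ above; merge_sizes is a hypothetical helper written for illustration):

    from itertools import zip_longest

    def merge_sizes(per_function_sizes):
        # Sort each function's allocation sizes in descending order, then take
        # the position-wise maximum across functions: the i-th largest
        # allocations of all functions share one storage of the largest size.
        ranked = [sorted(sizes, reverse=True) for sizes in per_function_sizes]
        return [max(column) for column in zip_longest(*ranked, fillvalue=0)]

    # "global" scope: func1 allocates 128 and 256 bytes; func2 allocates 192,
    # 64 and 512 bytes. The shared cuda_graph_alloc keeps 512, 192 and 64.
    assert merge_sizes([[128, 256], [192, 64, 512]]) == [512, 192, 64]
    # "ipc_memory" scope: 512 and 1024 merge into a single 1024-byte storage.
    assert merge_sizes([[512], [1024]]) == [1024]

As the comment in the C++ notes, this sharing is sound only because the rewritten functions are assumed not to run concurrently.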
diff --git a/tests/python/relax/test_transform_rewrite_cuda_graph.py b/tests/python/relax/test_transform_rewrite_cuda_graph.py
index 43b26f110f..9db285fea6 100644
--- a/tests/python/relax/test_transform_rewrite_cuda_graph.py
+++ b/tests/python/relax/test_transform_rewrite_cuda_graph.py
@@ -107,7 +107,7 @@ def test_rewrite_cuda_graph():
             return gv
 
         @R.function(private=True)
-        def cuda_graph_capture(alloc: R.Tensor((2, 4), dtype="float32"), alloc1: R.Tensor((2, 4), dtype="float32"), storage: R.Object, storage2: R.Object) -> R.Tuple(R.Tensor((2, 4), dtype="float32")):
+        def main_cuda_graph_capture(alloc: R.Tensor((2, 4), dtype="float32"), alloc1: R.Tensor((2, 4), dtype="float32"), storage: R.Object, storage2: R.Object) -> R.Tuple(R.Tensor((2, 4), dtype="float32")):
             R.func_attr({"relax.force_pure": True})
             cls = Expected
             _2: R.Tuple = cls.exp(alloc, alloc1)
@@ -133,7 +133,7 @@ def test_rewrite_cuda_graph():
             storage1: R.Object = gv[1]
             alloc1: R.Tensor((2, 4), dtype="float32") = R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([2, 4]), R.dtype("float32"))
             storage2: R.Object = gv[2]
-            gv1: R.Tuple(R.Tensor((2, 4), dtype="float32")) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.run_or_capture", (cls.cuda_graph_capture, (alloc, alloc1, storage, storage2), R.prim_value(0)), sinfo_args=(R.Tuple(R.Tensor((2, 4), dtype="float32")),))
+            gv1: R.Tuple(R.Tensor((2, 4), dtype="float32")) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.run_or_capture", (cls.main_cuda_graph_capture, (alloc, alloc1, storage, storage2), R.prim_value(0)), sinfo_args=(R.Tuple(R.Tensor((2, 4), dtype="float32")),))
             alloc3: R.Tensor((2, 4), dtype="float32") = gv1[0]
             alloc4: R.Tensor((2, 4), dtype="float32") = R.builtin.alloc_tensor(R.shape([2, 4]), R.dtype("float32"), R.prim_value(0))
             _6: R.Tuple = cls.exp(alloc3, alloc4)
@@ -191,7 +191,7 @@ def test_tuple():
             _5: R.Tuple = R.memory.kill_tensor(alloc2)
             _6: R.Tuple = R.memory.kill_storage(storage)
             _7: R.Tuple = R.memory.kill_storage(storage1)
-            return alloc2
+            return alloc3
 
     @I.ir_module
     class Expected:
@@ -217,7 +217,7 @@ def test_tuple():
             return gv
 
         @R.function(private=True)
-        def cuda_graph_capture(alloc: R.Tensor((2, 4), dtype="float32"), alloc1: R.Tensor((2, 4), dtype="float32"), storage: R.Object) -> R.Tuple(R.Tensor((2, 4), dtype="float32")):
+        def main_cuda_graph_capture(alloc: R.Tensor((2, 4), dtype="float32"), alloc1: R.Tensor((2, 4), dtype="float32"), storage: R.Object) -> R.Tuple(R.Tensor((2, 4), dtype="float32")):
             R.func_attr({"relax.force_pure": True})
             cls = Expected
             _: R.Tuple = cls.exp(alloc, alloc1)
@@ -242,14 +242,14 @@ def test_tuple():
             _: R.Tuple = cls.exp(x, alloc)
             storage1: R.Object = gv[1]
             alloc1: R.Tensor((2, 4), dtype="float32") = R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([2, 4]), R.dtype("float32"))
-            gv1: R.Tuple(R.Tensor((2, 4), dtype="float32")) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.run_or_capture", (cls.cuda_graph_capture, (alloc, alloc1, storage), R.prim_value(0)), sinfo_args=(R.Tuple(R.Tensor((2, 4), dtype="float32")),))
+            gv1: R.Tuple(R.Tensor((2, 4), dtype="float32")) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.run_or_capture", (cls.main_cuda_graph_capture, (alloc, alloc1, storage), R.prim_value(0)), sinfo_args=(R.Tuple(R.Tensor((2, 4), dtype="float32")),))
             alloc2: R.Tensor((2, 4), dtype="float32") = gv1[0]
             alloc3: R.Tensor((2, 4), dtype="float32") = R.builtin.alloc_tensor(R.shape([2, 4]), R.dtype("float32"), R.prim_value(0))
             _4: R.Tuple = cls.exp(alloc2, alloc3)
             _5: R.Tuple = R.memory.kill_tensor(alloc2)
             _6: R.Tuple = R.memory.kill_storage(storage)
             _7: R.Tuple = R.memory.kill_storage(storage1)
-            return alloc2
+            return alloc3
     # fmt: on
 
     after = relax.transform.RewriteCUDAGraph()(Before)
@@ -318,7 +318,7 @@ def test_vm_builtin():
             return gv
 
         @R.function(private=True)
-        def cuda_graph_capture(alloc: R.Tensor((2, 4), dtype="float32"), alloc1: R.Tensor((2, 4), dtype="float32"), storage: R.Object) -> R.Tuple(R.Tensor((2, 4), dtype="float32"), R.Tensor((2, 4), dtype="float32")):
+        def main_cuda_graph_capture(alloc: R.Tensor((2, 4), dtype="float32"), alloc1: R.Tensor((2, 4), dtype="float32"), storage: R.Object) -> R.Tuple(R.Tensor((2, 4), dtype="float32"), R.Tensor((2, 4), dtype="float32")):
             R.func_attr({"relax.force_pure": True})
             cls = Expected
             _2: R.Tuple = cls.exp(alloc, alloc1)
@@ -338,7 +338,7 @@ def test_vm_builtin():
             _1: R.Tuple = cls.exp(x, alloc)
             storage1: R.Object = gv[1]
             alloc1: R.Tensor((2, 4), dtype="float32") = R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([2, 4]), R.dtype("float32"))
-            gv1: R.Tuple(R.Tensor((2, 4), dtype="float32"), R.Tensor((2, 4), dtype="float32")) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.run_or_capture", (cls.cuda_graph_capture, (alloc, alloc1, storage), R.prim_value(0)), sinfo_args=(R.Tuple(R.Tensor((2, 4), dtype="float32"), R.Tensor((2, 4), dtype="float32")),))
+            gv1: R.Tuple(R.Tensor((2, 4), dtype="float32"), R.Tensor((2, 4), dtype="float32")) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.run_or_capture", (cls.main_cuda_graph_capture, (alloc, alloc1, storage), R.prim_value(0)), sinfo_args=(R.Tuple(R.Tensor((2, 4), dtype="float32"), R.Tensor((2, 4), dtype="float32")),))
             alloc2: R.Tensor((2, 4), dtype="float32") = gv1[1]
             lv: R.Tensor((2, 4), dtype="float32") = gv1[0]
             _4: R.Tuple = R.call_packed("vm.builtin.dummy", (x, lv), sinfo_args=(R.Tuple,))
@@ -528,7 +528,7 @@ def test_capture_fixed_inputs():
             return gv
 
         @R.function(private=True)
-        def cuda_graph_capture(
+        def main_cuda_graph_capture(
             lv: R.Tensor((16, 32, 32, 16), dtype="float16"),
             lv1: R.Tensor((16, 3, 3, 16), dtype="float16"),
             alloc1: R.Tensor((16, 32, 32, 16), dtype="float16"),
@@ -635,7 +635,7 @@ def test_capture_fixed_inputs():
             ) = R.call_builtin_with_ctx(
                 "vm.builtin.cuda_graph.run_or_capture",
                 (
-                    cls.cuda_graph_capture,
+                    cls.main_cuda_graph_capture,
                     (lv_1, lv1, alloc1, alloc, params, storage),
                     R.prim_value(0),
                 ),
@@ -728,7 +728,7 @@ def test_static_args():
             return gv
 
         @R.function(private=True)
-        def cuda_graph_capture(alloc0: R.Tensor((8,), dtype="float32")) -> R.Tuple:
+        def main_cuda_graph_capture(alloc0: R.Tensor((8,), dtype="float32")) -> R.Tuple:
             R.func_attr({"relax.force_pure": True})
             _: R.Object = R.call_packed("dummy_func", alloc0, R.dtype("float32"), R.str("string"))
             gv: R.Tuple = R.tuple()
@@ -748,7 +748,7 @@ def test_static_args():
             )
             gv1: R.Tuple = R.call_builtin_with_ctx(
                 "vm.builtin.cuda_graph.run_or_capture",
-                (cls.cuda_graph_capture, (alloc0,), R.prim_value(0)),
+                (cls.main_cuda_graph_capture, (alloc0,), R.prim_value(0)),
                 sinfo_args=(R.Tuple,),
             )
             return R.tuple()
@@ -822,7 +822,7 @@ def test_dynamic_capture():
             return gv
 
         @R.function(private=True)
-        def cuda_graph_capture(
+        def main_cuda_graph_capture(
             alloc1: R.Tensor(("m",), dtype="float32"),
             alloc2: R.Tensor(("m",), dtype="float32"),
             shape_expr: R.Shape(["m"]),
@@ -858,7 +858,7 @@ def test_dynamic_capture():
             R.call_builtin_with_ctx(
                 "vm.builtin.cuda_graph.run_or_capture",
                 (
-                    cls.cuda_graph_capture,
+                    cls.main_cuda_graph_capture,
                     (alloc1, alloc2, R.shape([m])),
                     R.prim_value(0),
                     R.shape([m]),
@@ -875,5 +875,218 @@ def test_dynamic_capture():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+class TestMergeAllocFuncs(BaseCompare):
+    @I.ir_module
+    class Before:
+        @R.function
+        def func1():
+            R.func_attr({"relax.force_pure": True})
+            storage1 = R.memory.alloc_storage(R.shape([128]), 0, "global", "float32")
+            storage2 = R.memory.alloc_storage(R.shape([256]), 0, "global", "float32")
+            storage3 = R.memory.alloc_storage(R.shape([512]), 0, "ipc_memory", "float32")
+            alloc1 = R.memory.alloc_tensor(storage1, 0, R.shape([128]), "float32")
+            alloc2 = R.memory.alloc_tensor(storage2, 0, R.shape([256]), "float32")
+            alloc3 = R.memory.alloc_tensor(storage3, 0, R.shape([512]), "float32")
+            R.call_packed("dummy", alloc1, alloc2, alloc3, sinfo_args=(R.Tuple,))
+            return R.tuple()
+
+        @R.function
+        def func2():
+            R.func_attr({"relax.force_pure": True})
+            storage1 = R.memory.alloc_storage(R.shape([192]), 0, "global", "float32")
+            storage2 = R.memory.alloc_storage(R.shape([64]), 0, "global", "float32")
+            storage3 = R.memory.alloc_storage(R.shape([1024]), 0, "ipc_memory", "float32")
+            storage4 = R.memory.alloc_storage(R.shape([512]), 0, "global", "float32")
+            alloc1 = R.memory.alloc_tensor(storage1, 0, R.shape([192]), "float32")
+            alloc2 = R.memory.alloc_tensor(storage2, 0, R.shape([64]), "float32")
+            alloc3 = R.memory.alloc_tensor(storage3, 0, R.shape([1024]), "float32")
+            alloc4 = R.memory.alloc_tensor(storage4, 0, R.shape([512]), "float32")
+            R.call_packed("dummy", alloc1, alloc2, alloc3, alloc4, sinfo_args=(R.Tuple,))
+            return R.tuple()
+
+    @I.ir_module
+    class Expected:
+        @R.function(private=True)
+        def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object, R.Object, R.Object):
+            R.func_attr({"relax.force_pure": True})
+            storage4: R.Object = R.memory.alloc_storage(
+                R.shape([512]), R.prim_value(0), R.str("global"), R.dtype("float32")
+            )
+            storage1: R.Object = R.memory.alloc_storage(
+                R.shape([192]), R.prim_value(0), R.str("global"), R.dtype("float32")
+            )
+            storage2: R.Object = R.memory.alloc_storage(
+                R.shape([64]), R.prim_value(0), R.str("global"), R.dtype("float32")
+            )
+            storage3: R.Object = R.memory.alloc_storage(
+                R.shape([1024]), R.prim_value(0), R.str("ipc_memory"), R.dtype("float32")
+            )
+            gv: R.Tuple(R.Object, R.Object, R.Object, R.Object) = (
+                storage4,
+                storage1,
+                storage2,
+                storage3,
+            )
+            return gv
+
+        @R.function
+        def func1() -> R.Tuple:
+            R.func_attr({"relax.force_pure": True})
+            cls = Expected
+            gv: R.Tuple(R.Object, R.Object, R.Object, R.Object) = R.call_builtin_with_ctx(
+                "vm.builtin.cuda_graph.get_cached_alloc",
+                (cls.cuda_graph_alloc, R.prim_value(0)),
+                sinfo_args=(R.Tuple(R.Object, R.Object, R.Object, R.Object),),
+            )
+            storage1: R.Object = gv[1]
+            storage2: R.Object = gv[0]
+            storage3: R.Object = gv[3]
+            alloc1: R.Tensor((128,), dtype="float32") = R.memory.alloc_tensor(
+                storage1, R.prim_value(0), R.shape([128]), R.dtype("float32")
+            )
+            alloc2: R.Tensor((256,), dtype="float32") = R.memory.alloc_tensor(
+                storage2, R.prim_value(0), R.shape([256]), R.dtype("float32")
+            )
+            alloc3: R.Tensor((512,), dtype="float32") = R.memory.alloc_tensor(
+                storage3, R.prim_value(0), R.shape([512]), R.dtype("float32")
+            )
+            R.call_builtin_with_ctx(
+                "vm.builtin.cuda_graph.run_or_capture",
+                (cls.func1_cuda_graph_capture, (alloc1, alloc2, alloc3), R.prim_value(0)),
+                sinfo_args=(R.Tuple,),
+            )
+            return R.tuple()
+
+        @R.function(private=True)
+        def func1_cuda_graph_capture(
+            alloc1: R.Tensor((128,), dtype="float32"),
+            alloc2: R.Tensor((256,), dtype="float32"),
+            alloc3: R.Tensor((512,), dtype="float32"),
+        ) -> R.Tuple:
+            R.func_attr({"relax.force_pure": True})
+            R.call_packed("dummy", alloc1, alloc2, alloc3, sinfo_args=(R.Tuple,))
+            R.tuple()
+            return R.tuple()
+
+        @R.function
+        def func2() -> R.Tuple:
+            R.func_attr({"relax.force_pure": True})
+            cls = Expected
+            gv2: R.Tuple(R.Object, R.Object, R.Object, R.Object) = R.call_builtin_with_ctx(
+                "vm.builtin.cuda_graph.get_cached_alloc",
+                (cls.cuda_graph_alloc, R.prim_value(0)),
+                sinfo_args=(R.Tuple(R.Object, R.Object, R.Object, R.Object),),
+            )
+            storage11: R.Object = gv2[1]
+            storage21: R.Object = gv2[2]
+            storage31: R.Object = gv2[3]
+            storage4: R.Object = gv2[0]
+            alloc1: R.Tensor((192,), dtype="float32") = R.memory.alloc_tensor(
+                storage11, R.prim_value(0), R.shape([192]), R.dtype("float32")
+            )
+            alloc2: R.Tensor((64,), dtype="float32") = R.memory.alloc_tensor(
+                storage21, R.prim_value(0), R.shape([64]), R.dtype("float32")
+            )
+            alloc3: R.Tensor((1024,), dtype="float32") = R.memory.alloc_tensor(
+                storage31, R.prim_value(0), R.shape([1024]), R.dtype("float32")
+            )
+            alloc4: R.Tensor((512,), dtype="float32") = R.memory.alloc_tensor(
+                storage4, R.prim_value(0), R.shape([512]), R.dtype("float32")
+            )
+            R.call_builtin_with_ctx(
+                "vm.builtin.cuda_graph.run_or_capture",
+                (cls.func2_cuda_graph_capture, (alloc1, alloc2, alloc3, alloc4), R.prim_value(1)),
+                sinfo_args=(R.Tuple,),
+            )
+            return R.tuple()
+
+        @R.function(private=True)
+        def func2_cuda_graph_capture(
+            alloc1: R.Tensor((192,), dtype="float32"),
+            alloc2: R.Tensor((64,), dtype="float32"),
+            alloc3: R.Tensor((1024,), dtype="float32"),
+            alloc4: R.Tensor((512,), dtype="float32"),
+        ) -> R.Tuple:
+            R.func_attr({"relax.force_pure": True})
+            R.call_packed("dummy", alloc1, alloc2, alloc3, alloc4, sinfo_args=(R.Tuple,))
+            R.tuple()
+            return R.tuple()
+
+
+class TestDisableCaptureOutput(BaseCompare):
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor((8,), "float32")) -> R.Tuple(R.Tensor((8,), "float32")):
+            R.func_attr({"relax.force_pure": True})
+            storage1 = R.memory.alloc_storage(R.shape([8]), 0, "global", "float32")
+            alloc1 = R.memory.alloc_tensor(storage1, 0, R.shape([8]), "float32")
+            _ = R.call_packed("dummy", x, alloc1, sinfo_args=(R.Tuple,))
+            storage2 = R.memory.alloc_storage(R.shape([8]), 0, "global", "float32")
+            alloc2 = R.memory.alloc_tensor(storage2, 0, R.shape([8]), "float32")
+            _1 = R.call_packed("dummy", alloc1, alloc2, sinfo_args=(R.Tuple,))
+            storage3 = R.memory.alloc_storage(R.shape([8]), 0, "global", "float32")
+            alloc3 = R.memory.alloc_tensor(storage3, 0, R.shape([8]), "float32")
+            _2 = R.call_packed("dummy", alloc2, alloc3, sinfo_args=(R.Tuple,))
+            gv = (alloc3,)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function(private=True)
+        def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object):
+            R.func_attr({"relax.force_pure": True})
+            storage1: R.Object = R.memory.alloc_storage(
+                R.shape([8]), R.prim_value(0), R.str("global"), R.dtype("float32")
+            )
+            storage2: R.Object = R.memory.alloc_storage(
+                R.shape([8]), R.prim_value(0), R.str("global"), R.dtype("float32")
+            )
+            gv: R.Tuple(R.Object, R.Object) = storage1, storage2
+            return gv
+
+        @R.function(private=True)
+        def main_cuda_graph_capture(
+            alloc1: R.Tensor((8,), dtype="float32"), alloc2: R.Tensor((8,), dtype="float32")
+        ) -> R.Tuple:
+            R.func_attr({"relax.force_pure": True})
+            R.call_packed("dummy", alloc1, alloc2, sinfo_args=(R.Tuple,))
+            R.tuple()
+            return R.tuple()
+
+        @R.function
+        def main(x: R.Tensor((8,), dtype="float32")) -> R.Tuple(R.Tensor((8,), dtype="float32")):
+            R.func_attr({"relax.force_pure": True})
+            cls = Expected
+            gv: R.Tuple(R.Object, R.Object) = R.call_builtin_with_ctx(
+                "vm.builtin.cuda_graph.get_cached_alloc",
+                (cls.cuda_graph_alloc, R.prim_value(0)),
+                sinfo_args=(R.Tuple(R.Object, R.Object),),
+            )
+            storage1: R.Object = gv[0]
+            alloc1: R.Tensor((8,), dtype="float32") = R.memory.alloc_tensor(
+                storage1, R.prim_value(0), R.shape([8]), R.dtype("float32")
+            )
+            R.call_packed("dummy", x, alloc1, sinfo_args=(R.Tuple,))
+            storage2: R.Object = gv[1]
+            alloc2: R.Tensor((8,), dtype="float32") = R.memory.alloc_tensor(
+                storage2, R.prim_value(0), R.shape([8]), R.dtype("float32")
+            )
+            R.call_builtin_with_ctx(
+                "vm.builtin.cuda_graph.run_or_capture",
+                (cls.main_cuda_graph_capture, (alloc1, alloc2), R.prim_value(0)),
+                sinfo_args=(R.Tuple,),
+            )
+            storage3: R.Object = R.memory.alloc_storage(
+                R.shape([8]), R.prim_value(0), R.str("global"), R.dtype("float32")
+            )
+            alloc3: R.Tensor((8,), dtype="float32") = R.memory.alloc_tensor(
+                storage3, R.prim_value(0), R.shape([8]), R.dtype("float32")
+            )
+            R.call_packed("dummy", alloc2, alloc3, sinfo_args=(R.Tuple,))
+            gv = (alloc3,)
+            return gv
+
+
 if __name__ == "__main__":
     tvm.testing.main()

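The two new test classes derive from the BaseCompare helper defined earlier in this test file (not shown in the diff). A hedged sketch of what such a helper typically looks like in TVM's test suite:

    # Sketch only: the real BaseCompare lives earlier in
    # test_transform_rewrite_cuda_graph.py and is not part of this diff.
    class BaseCompare(tvm.testing.CompareBeforeAfter):
        transform = relax.transform.RewriteCUDAGraph()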
