This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ef32a611e3 [Relax] Enable capturing symbolic shapes in cuda graph (#16815)
ef32a611e3 is described below

commit ef32a611e386251c86fa255db4b8530b291dde11
Author: Wuwei Lin <[email protected]>
AuthorDate: Sat Mar 30 14:31:22 2024 -0700

    [Relax] Enable capturing symbolic shapes in cuda graph (#16815)
    
    * [Relax] Enable capturing symbolic shapes in cuda graph
    
    * Add Bind sinfo util
    
    * Bind ret sinfo
    
    * address comments
    
    * add comments
    
    * fix
    
    * update test
---
 include/tvm/relax/utils.h                          |   7 +
 src/relax/transform/rewrite_cuda_graph.cc          | 161 ++++++++++++++++++---
 src/relax/utils.cc                                 |   4 +
 src/runtime/relax_vm/cuda/cuda_graph_builtin.cc    |  62 ++++++--
 .../relax/test_transform_rewrite_cuda_graph.py     | 118 +++++++++++++++
 5 files changed, 321 insertions(+), 31 deletions(-)

diff --git a/include/tvm/relax/utils.h b/include/tvm/relax/utils.h
index 74e773abe7..e48c1856f9 100644
--- a/include/tvm/relax/utils.h
+++ b/include/tvm/relax/utils.h
@@ -50,6 +50,13 @@ namespace relax {
 TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds,
                   const tvm::Map<tir::Var, PrimExpr>& symbolic_var_map = {});
 
+/*!
+ * \brief Bind the symbolic variables to a StructInfo. This is a helper function usually called by
+ * other pass functions to help optimizations.
+ */
+TVM_DLL StructInfo Bind(const StructInfo& sinfo,
+                        const tvm::Map<tir::Var, PrimExpr>& symbolic_var_map);
+
 /*!
  * \brief Infer a binding map for symbolic variables
  *
diff --git a/src/relax/transform/rewrite_cuda_graph.cc b/src/relax/transform/rewrite_cuda_graph.cc
index b67a638dd6..25b229ebce 100644
--- a/src/relax/transform/rewrite_cuda_graph.cc
+++ b/src/relax/transform/rewrite_cuda_graph.cc
@@ -53,6 +53,8 @@
 #include <tvm/relax/backend.h>
 #include <tvm/relax/expr_functor.h>
 #include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/stmt_functor.h>
 
 #include "../../support/arena.h"
 #include "../../support/ordered_set.h"
@@ -82,6 +84,8 @@ struct LiftedFunctionRewritePlan {
   std::vector<const VarNode*> outputs;
   // The corresponding binding vars in the original function of the inputs of the lifted function
   std::vector<const VarNode*> inputs;
+  // The tir vars in the original function that are propagated to the lifted function
+  Optional<ShapeExpr> propogated_tir_vars = NullOpt;
 };
 
 /*! \brief Builder of the lifted function for cuda graph capturing or allocations */
@@ -98,6 +102,11 @@ class FuncBuilder : public ExprMutator {
    * \param var The variable to mark as input
    */
   void MarkInput(const VarNode* var) { inputs_.push_back(var); }
+  /*!
+   * \brief Mark a TIR variable as the ShapeExpr input of the new function.
+   * \param var The variable to mark as input
+   */
+  void MarkShapeExprInput(const tir::VarNode* var) { shape_expr_inputs_.push_back(var); }
   /*!
    * \brief Mark a variable as the output of the new function. The variable must be the LHS of an
    * existing binding in the new function.
@@ -111,12 +120,27 @@ class FuncBuilder : public ExprMutator {
   /*! \brief Build the new function */
   Function Build() {
     Array<Var> params;
+    Optional<Var> shape_expr = NullOpt;
+    if (shape_expr_inputs_.size()) {
+      Array<PrimExpr> tir_vars;
+      for (const auto* var : shape_expr_inputs_) {
+        auto new_var = GetRef<tir::Var>(var).copy_with_suffix("");
+        tir_var_remap_.Set(GetRef<tir::Var>(var), new_var);
+        tir_vars.push_back(new_var);
+      }
+      shape_expr = Var("shape_expr", ShapeStructInfo(tir_vars));
+    }
     // Set up the parameters
     for (const auto* input : inputs_) {
-      auto new_var = Var(input->name_hint(), Downcast<Optional<StructInfo>>(input->struct_info_));
+      auto new_var = Var(
+          input->name_hint(),
+          VisitExprDepStructInfoField(Downcast<Optional<StructInfo>>(input->struct_info_).value()));
       var_remap_[input->vid] = new_var;
       params.push_back(new_var);
     }
+    if (shape_expr) {
+      params.push_back(shape_expr.value());
+    }
     // Emit the function body
     builder_->BeginBindingBlock();
     for (const auto* binding : bindings_) {
@@ -137,9 +161,13 @@ class FuncBuilder : public ExprMutator {
     return func;
   }
 
+  PrimExpr VisitPrimExpr(const PrimExpr& expr) { return tir::Substitute(expr, tir_var_remap_); }
+
   support::OrderedSet<const VarNode*> inputs_;
   support::OrderedSet<const VarNode*> outputs_;
+  support::OrderedSet<const tir::VarNode*> shape_expr_inputs_;
   std::vector<const VarBindingNode*> bindings_;
+  Map<tir::Var, PrimExpr> tir_var_remap_;
 };
 
 /*!
@@ -159,6 +187,7 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
             static_vars_.insert(func->params[i].get());
           }
         }
+        CollectSymbolicVarHints(func);
         VisitExpr(func);
       }
     }
@@ -174,6 +203,13 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
       for (const auto* binding : region->bindings_) {
         plan.lifted_bindings.insert(binding->var.get());
       }
+      if (region->shape_expr_inputs_.size()) {
+        Array<PrimExpr> tir_vars;
+        for (const auto* var : region->shape_expr_inputs_) {
+          tir_vars.push_back(GetRef<PrimExpr>(var));
+        }
+        plan.propogated_tir_vars = ShapeExpr(tir_vars);
+      }
       plan.inputs.assign(region->inputs_.begin(), region->inputs_.end());
       plan.outputs.assign(region->outputs_.begin(), region->outputs_.end());
       return plan;
@@ -189,6 +225,18 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
     return plans;
   }
 
+  /*!
+   * \brief Collect the name hints of the symbolic variables that are allowed to be captured.
+   */
+  void CollectSymbolicVarHints(const Function& func) {
+    capture_symbolic_vars_.clear();
+    if (auto symbolic_vars =
+            func->attrs.GetAttr<Array<String>>("relax.rewrite_cuda_graph.capture_symbolic_vars")) {
+      for (const auto& var : symbolic_vars.value()) {
+        capture_symbolic_vars_.insert(var);
+      }
+    }
+  }
   /*!
    *\brief Start a new static region. This method should be called when encountering a
    * CUDA kernel launch (calls to PrimFunc or ExternFunc) that only depends on static parameters.
@@ -239,8 +287,9 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
     // Check whether the call can be lifted to the capture function. It requires all the arguments
     // to be static and the call to be a kernel launch or a pure operation (e.g. memory view).
     std::vector<const VarNode*> args;
+    std::vector<const tir::VarNode*> tir_vars;
     bool is_all_static = [&]() {
-      if (!IsStatic(call->args, &args)) {
+      if (!IsStatic(call->args, &args, &tir_vars)) {
         return false;
       }
       if (call_gv != nullptr && !call_prim_func) {
@@ -276,7 +325,7 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
         StartRegion();
       }
       AddStaticBinding(binding, /*is_alloc_storage=*/false);
-      MarkAsFuncInput(args);
+      MarkAsFuncInput(args, tir_vars);
     } else {
       EndRegion();
     }
@@ -284,7 +333,8 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
     MarkAsFuncOutput(args);
   }
 
-  void MarkAsFuncInput(const std::vector<const VarNode*>& vars) {
+  void MarkAsFuncInput(const std::vector<const VarNode*>& vars,
+                       const std::vector<const tir::VarNode*>& tir_vars = {}) {
     if (current_.capture_builder == nullptr) {
       return;
     }
@@ -294,6 +344,9 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
         current_.capture_builder->MarkInput(var);
       }
     }
+    for (const tir::VarNode* tir_var : tir_vars) {
+      current_.capture_builder->MarkShapeExprInput(tir_var);
+    }
   }
 
   void MarkAsFuncOutput(const std::vector<const VarNode*>& vars) {
@@ -321,9 +374,10 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
 
   void VisitBinding_(const VarBindingNode* binding, const TupleNode* tuple) final {
     std::vector<const VarNode*> args;
-    if (IsStatic(tuple->fields, &args)) {
+    std::vector<const tir::VarNode*> tir_vars;
+    if (IsStatic(tuple->fields, &args, &tir_vars)) {
       AddStaticBinding(binding, false);
-      MarkAsFuncInput(args);
+      MarkAsFuncInput(args, tir_vars);
     } else {
       EndRegion();
     }
@@ -343,48 +397,83 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
   }
 
   bool IsStatic(const PrimExpr& expr,
-                [[maybe_unused]] std::vector<const VarNode*>* vars_collector = nullptr) {
-    return expr->IsInstance<tir::IntImmNode>() || expr->IsInstance<tir::FloatImmNode>();
+                [[maybe_unused]] std::vector<const VarNode*>* vars_collector = nullptr,
+                std::vector<const tir::VarNode*>* tir_vars_collector = nullptr) {
+    bool is_static = true;
+    tir::PostOrderVisit(expr, [&](const ObjectRef& e) {
+      if (auto var = e.as<tir::VarNode>()) {
+        if (!capture_symbolic_vars_.count(var->name_hint)) {
+          is_static = false;
+          return;
+        }
+        if (tir_vars_collector != nullptr) {
+          tir_vars_collector->push_back(var);
+        }
+      }
+    });
+    return is_static;
   }
 
-  bool IsStatic(const Expr& expr, std::vector<const VarNode*>* vars_collector = nullptr) {
+  bool IsStatic(const Expr& expr, std::vector<const VarNode*>* vars_collector = nullptr,
+                std::vector<const tir::VarNode*>* tir_vars_collector = nullptr) {
     if (expr->IsInstance<ConstantNode>() || expr->IsInstance<DataTypeImmNode>() ||
-        expr->IsInstance<StringImmNode>()) {
+        expr->IsInstance<StringImmNode>() || expr->IsInstance<GlobalVarNode>()) {
       return true;
     }
     if (const auto* prim_value = expr.as<PrimValueNode>()) {
-      return IsStatic(prim_value->value, vars_collector);
+      return IsStatic(prim_value->value, vars_collector, tir_vars_collector);
     }
     if (const auto* var = expr.as<VarNode>()) {
       if (vars_collector != nullptr) {
         vars_collector->push_back(var);
       }
-      return static_vars_.count(var);
+      // recursively check the struct info to collect the symbolic TIR vars
+      return static_vars_.count(var) && IsStatic(Downcast<StructInfo>(var->struct_info_.value()),
+                                                 vars_collector, tir_vars_collector);
     }
 
     if (const auto* shape = expr.as<ShapeExprNode>()) {
-      return IsStatic(shape->values, vars_collector);
+      return IsStatic(shape->values, vars_collector, tir_vars_collector);
     }
     if (const auto* tuple = expr.as<TupleNode>()) {
-      return IsStatic(tuple->fields, vars_collector);
+      return IsStatic(tuple->fields, vars_collector, tir_vars_collector);
     }
     return false;
   }
 
   template <typename T>
-  bool IsStatic(const Array<T>& exprs, std::vector<const VarNode*>* vars_collector = nullptr) {
+  bool IsStatic(const Array<T>& exprs, std::vector<const VarNode*>* vars_collector = nullptr,
+                std::vector<const tir::VarNode*>* tir_vars_collector = nullptr) {
     bool result = true;
     for (const auto& expr : exprs) {
       // If vars_collector is provided, we will collect all the vars in the exprs and we should
       // not perform short-circuiting.
-      result &= IsStatic(expr, vars_collector);
-      if (!vars_collector && !result) {
+      result &= IsStatic(expr, vars_collector, tir_vars_collector);
+      if (vars_collector == nullptr && tir_vars_collector == nullptr && !result) {
         return false;
       }
     }
     return result;
   }
 
+  bool IsStatic(const StructInfo& sinfo, std::vector<const VarNode*>* vars_collector = nullptr,
+                std::vector<const tir::VarNode*>* tir_vars_collector = nullptr) {
+    if (const auto* tensor_sinfo = sinfo.as<TensorStructInfoNode>()) {
+      if (auto shape = tensor_sinfo->GetShape()) {
+        return IsStatic(shape.value(), vars_collector, tir_vars_collector);
+      }
+    } else if (const auto* shape_sinfo = sinfo.as<ShapeStructInfoNode>()) {
+      if (shape_sinfo->values) {
+        return IsStatic(shape_sinfo->values.value(), vars_collector, tir_vars_collector);
+      }
+    } else if (const auto* tuple_sinfo = sinfo.as<TupleStructInfoNode>()) {
+      return IsStatic(tuple_sinfo->fields, vars_collector, tir_vars_collector);
+    } else if (sinfo.as<ObjectStructInfoNode>() || sinfo.as<PrimStructInfoNode>()) {
+      return true;
+    }
+    return false;
+  }
+
  private:
   bool IsStaticAllocStorage(const VarBindingNode* binding) {
     // Check if the allocation has constant shape
@@ -431,6 +520,8 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
   Scope current_;
   // Variables whose buffer address is fixed
   std::unordered_set<const VarNode*> static_vars_;
+  // The name of the variables that are allowed to be symbolic
+  std::unordered_set<String> capture_symbolic_vars_;
   // Binding to the FuncBuilder if the binding is lifted. This is used to update the inputs/outputs
   // of the lifted function when its binding is used outside.
   std::unordered_map<const VarNode*, FuncBuilder*> binding_to_region_;
@@ -475,6 +566,8 @@ class CUDAGraphRewriter : public ExprMutator {
     auto gv_func =
         builder_->AddFunction(plan.func, plan.is_alloc ? "cuda_graph_alloc" : "cuda_graph_capture");
     if (plan.is_alloc) {
+      // Storage allocation should be fully static and shouldn't depend on any symbolic variables.
+      ICHECK(!plan.propogated_tir_vars.defined());
       ICHECK(plan.inputs.empty());
       launch_subgraph =
           Call(call_builtin_with_ctx_op,
@@ -482,15 +575,39 @@ class CUDAGraphRewriter : public ExprMutator {
                 Tuple({gv_func, PrimValue(IntImm(DataType::Int(64), index_alloc_++))})},
                Attrs(), {plan.func->ret_struct_info});
     } else {
+      StructInfo call_sinfo = plan.func->ret_struct_info;
+      // Arguments of the lifted function
       Array<Expr> args;
       for (const auto& arg : plan.inputs) {
         args.push_back(VisitExpr_(arg));
       }
-      launch_subgraph = Call(
-          call_builtin_with_ctx_op,
-          {builtin_run_or_capture,
-           Tuple({gv_func, Tuple(args), PrimValue(IntImm(DataType::Int(64), index_capture_++))})},
-          Attrs(), {plan.func->ret_struct_info});
+      if (plan.propogated_tir_vars.defined()) {
+        ShapeExpr propogated_tir_vars = plan.propogated_tir_vars.value();
+        args.push_back(propogated_tir_vars);
+        // The ret_struct_info of the lifted function can contain symbolic variables. We need to
+        // bind the symbolic parameters to the actual values.
+        const auto& shape_expr = plan.func->params.back();
+        auto symbolic_params =
+            Downcast<ShapeStructInfo>(shape_expr->struct_info_.value())->values.value();
+        Map<tir::Var, PrimExpr> tir_var_remap;
+        ICHECK_EQ(symbolic_params.size(), propogated_tir_vars->values.size());
+        for (int i = 0; i < static_cast<int>(symbolic_params.size()); ++i) {
+          tir_var_remap.Set(Downcast<tir::Var>(symbolic_params[i]), propogated_tir_vars->values[i]);
+        }
+        call_sinfo = Bind(call_sinfo, tir_var_remap);
+      }
+      // Arguments of builtin_run_or_capture
+      Array<Expr> tuple_arg_fields{gv_func, Tuple(args),
+                                   PrimValue(IntImm(DataType::Int(64), index_capture_++))};
+      if (plan.propogated_tir_vars.defined()) {
+        // The shape expr is explicitly passed twice, one as the last argument of the lifted
+        // function, one as the last argument of builtin_run_or_capture as the cache key. Explicitly
+        // passing it twice simplifies the handling during the capture phase.
+        tuple_arg_fields.push_back(plan.propogated_tir_vars.value());
+      }
+      launch_subgraph =
+          Call(call_builtin_with_ctx_op, {builtin_run_or_capture, Tuple(tuple_arg_fields)}, Attrs(),
+               {call_sinfo});
     }
     Expr ret_value = builder_->Emit(launch_subgraph);
     for (int i = 0; i < static_cast<int>(plan.outputs.size()); ++i) {
diff --git a/src/relax/utils.cc b/src/relax/utils.cc
index a15ee79fac..77e6b33f0c 100644
--- a/src/relax/utils.cc
+++ b/src/relax/utils.cc
@@ -144,6 +144,10 @@ Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds,
   return ExprBinder(binds, symbolic_var_map).VisitExpr(expr);
 }
 
+StructInfo Bind(const StructInfo& sinfo, const tvm::Map<tir::Var, PrimExpr>& symbolic_var_map) {
+  return ExprBinder({}, symbolic_var_map).VisitExprDepStructInfoField(sinfo);
+}
+
 tvm::Map<tir::Var, PrimExpr> InferSymbolicVarMap(
     const tvm::Map<relax::Var, relax::Expr>& relax_var_remap, arith::Analyzer* analyzer) {
   tvm::Map<tir::Var, PrimExpr> tir_var_remap;
diff --git a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
index f6eef9ca25..02b6da7dab 100644
--- a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
+++ b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
@@ -26,11 +26,45 @@
 #include <tvm/runtime/registry.h>
 #include <tvm/runtime/relax_vm/vm.h>
 
+#include "../../../support/utils.h"
 #include "../../cuda/cuda_common.h"
 namespace tvm {
 namespace runtime {
 namespace relax_vm {
 
+struct CUDAGraphCaptureKey {
+  // The unique index of the capture function within the module
+  int64_t index;
+  // The symbolic variables the capture function depends on. When the capture function is ran with
+  // different symbolic variable values, the CUDA graph will be re-captured as a different version,
+  // identified by this shape tuple. This is default constructed as an empty tuple.
+  ShapeTuple shape_expr;
+
+  CUDAGraphCaptureKey(int64_t index, const Optional<ShapeTuple>& shape_expr) : index(index) {
+    if (shape_expr) {
+      this->shape_expr = shape_expr.value();
+    }
+  }
+};
+
+struct CUDAGraphCaptureKeyHash {
+  size_t operator()(const CUDAGraphCaptureKey& key) const {
+    std::hash<int64_t> hash_fn;
+    size_t hash = hash_fn(key.index);
+    for (const auto& shape : key.shape_expr) {
+      support::HashCombine(hash, hash_fn(shape));
+    }
+    return hash;
+  }
+};
+
+struct CUDAGraphCaptureKeyEqual {
+  bool operator()(const CUDAGraphCaptureKey& lhs, const CUDAGraphCaptureKey& rhs) const {
+    return lhs.index == rhs.index && std::equal(lhs.shape_expr.begin(), lhs.shape_expr.end(),
+                                                rhs.shape_expr.begin(), rhs.shape_expr.end());
+  }
+};
+
 /*! \brief The cache states of a CUDA graph. */
 class CUDAGraphCache : public Object {
  public:
@@ -62,8 +96,9 @@ class CUDAGraphCache : public Object {
    * \return The return value of the capture function.
    */
   ObjectRef RunOrCapture(VirtualMachine* vm, const ObjectRef& capture_func, ObjectRef args,
-                         int64_t entry_index) {
-    if (auto it = capture_cache_.find(entry_index); it != capture_cache_.end()) {
+                         int64_t entry_index, Optional<ShapeTuple> shape_expr) {
+    CUDAGraphCaptureKey entry_key{entry_index, shape_expr};
+    if (auto it = capture_cache_.find(entry_key); it != capture_cache_.end()) {
       // Launch CUDA graph
       const auto& [states, exec] = it->second;
       CUDA_CALL(cudaGraphLaunch(exec, CUDAThreadEntry::ThreadLocal()->stream));
@@ -103,8 +138,8 @@ class CUDAGraphCache : public Object {
     CUDA_CALL(cudaStreamEndCapture(CUDAThreadEntry::ThreadLocal()->stream, &graph));
     std::swap(capture_stream, CUDAThreadEntry::ThreadLocal()->stream);
 
-    capture_cache_[entry_index] = entry;
-    CUDA_CALL(cudaGraphInstantiate(&capture_cache_[entry_index].exec, graph, NULL, NULL, 0));
+    capture_cache_[entry_key] = entry;
+    CUDA_CALL(cudaGraphInstantiate(&capture_cache_[entry_key].exec, graph, NULL, NULL, 0));
     CUDA_CALL(cudaStreamDestroy(capture_stream));
     CUDA_CALL(cudaGraphDestroy(graph));
     return entry.states;
@@ -134,7 +169,9 @@ class CUDAGraphCache : public Object {
    * \brief The cache of captured cuda graphs. The key is a unique index for the capture function.
    * The value is the result of the capture.
    */
-  std::unordered_map<int64_t, CaptureResult> capture_cache_;
+  std::unordered_map<CUDAGraphCaptureKey, CaptureResult, CUDAGraphCaptureKeyHash,
+                     CUDAGraphCaptureKeyEqual>
+      capture_cache_;
   /*!
    * \brief The cache of allocations. The key is a unique index for the allocation function.
    * The value is the cached allocations, which is a tuple of storages.
@@ -143,11 +180,18 @@ class CUDAGraphCache : public Object {
 };
 
 TVM_REGISTER_GLOBAL("vm.builtin.cuda_graph.run_or_capture")
-    .set_body_typed([](TVMArgValue vm_ptr, ObjectRef capture_func, ObjectRef func_args,
-                       int64_t entry_index) {
-      VirtualMachine* vm = VirtualMachine::GetContextPtr(vm_ptr);
+    .set_body([](TVMArgs args, TVMRetValue* rv) {
+      ICHECK(args.size() == 5 || args.size() == 4);
+      VirtualMachine* vm = VirtualMachine::GetContextPtr(args[0]);
+      ObjectRef capture_func = args[1];
+      ObjectRef func_args = args[2];
+      int64_t entry_index = args[3];
+      Optional<ShapeTuple> shape_expr = NullOpt;
+      if (args.size() == 5) {
+        shape_expr = args[4].AsObjectRef<ShapeTuple>();
+      }
       CUDAGraphCache* cache = CUDAGraphCache::Get();
-      return cache->RunOrCapture(vm, capture_func, func_args, entry_index);
+      *rv = cache->RunOrCapture(vm, capture_func, func_args, entry_index, shape_expr);
     });
 
 TVM_REGISTER_GLOBAL("vm.builtin.cuda_graph.get_cached_alloc")
diff --git a/tests/python/relax/test_transform_rewrite_cuda_graph.py b/tests/python/relax/test_transform_rewrite_cuda_graph.py
index 91b3fce264..43b26f110f 100644
--- a/tests/python/relax/test_transform_rewrite_cuda_graph.py
+++ b/tests/python/relax/test_transform_rewrite_cuda_graph.py
@@ -757,5 +757,123 @@ def test_static_args():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_dynamic_capture():
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def add_one(x_handle: T.handle, y_handle: T.handle):
+            m = T.int64()
+            x = T.match_buffer(x_handle, (m,), "float32")
+            y = T.match_buffer(y_handle, (m,), "float32")
+            for i in range(m):
+                with T.block("add"):
+                    vi = T.axis.remap("S", [i])
+                    y[vi] = x[vi] + T.float32(1)
+
+        @R.function
+        def main(x: R.Tensor(("m",), "float32")) -> R.Tensor(("m",), "float32"):
+            R.func_attr(
+                {"relax.rewrite_cuda_graph.capture_symbolic_vars": ["m"], "relax.force_pure": True}
+            )
+            m = T.int64()
+            storage: R.Object = R.memory.alloc_storage(
+                R.shape([16]), 0, "global", "float32"
+            )  # assume m is upper-bounded
+            alloc1: R.Tensor((m,), "float32") = R.memory.alloc_tensor(
+                storage, 0, R.shape([m]), "float32"
+            )
+            _ = Before.add_one(x, alloc1)
+            storage1: R.Object = R.memory.alloc_storage(R.shape([16]), 0, "global", "float32")
+            alloc2: R.Tensor((m,), "float32") = R.memory.alloc_tensor(
+                storage1, 0, R.shape([m]), "float32"
+            )
+            _ = Before.add_one(alloc1, alloc2)
+            alloc3: R.Tensor((m,), "float32") = R.builtin.alloc_tensor(
+                R.shape([m]), "float32", 0, "global"
+            )
+            _ = Before.add_one(alloc2, alloc3)
+            return alloc3
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def add_one(x_handle: T.handle, y_handle: T.handle):
+            m = T.int64()
+            x = T.match_buffer(x_handle, (m,))
+            y = T.match_buffer(y_handle, (m,))
+            # with T.block("root"):
+            for i in range(m):
+                with T.block("add"):
+                    vi = T.axis.spatial(m, i)
+                    T.reads(x[vi])
+                    T.writes(y[vi])
+                    y[vi] = x[vi] + T.float32(1)
+
+        @R.function(private=True)
+        def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object):
+            R.func_attr({"relax.force_pure": True})
+            storage: R.Object = R.memory.alloc_storage(
+                R.shape([16]), R.prim_value(0), R.str("global"), R.dtype("float32")
+            )
+            storage1: R.Object = R.memory.alloc_storage(
+                R.shape([16]), R.prim_value(0), R.str("global"), R.dtype("float32")
+            )
+            gv: R.Tuple(R.Object, R.Object) = storage, storage1
+            return gv
+
+        @R.function(private=True)
+        def cuda_graph_capture(
+            alloc1: R.Tensor(("m",), dtype="float32"),
+            alloc2: R.Tensor(("m",), dtype="float32"),
+            shape_expr: R.Shape(["m"]),
+        ):
+            m = T.int64()
+            R.func_attr({"relax.force_pure": True})
+            cls = Expected
+            cls.add_one(alloc1, alloc2)
+            gv = R.tuple()
+            return R.tuple()
+
+        @R.function
+        def main(x: R.Tensor(("m",), dtype="float32")) -> R.Tensor(("m",), dtype="float32"):
+            m = T.int64()
+            R.func_attr(
+                {"relax.force_pure": True, "relax.rewrite_cuda_graph.capture_symbolic_vars": ["m"]}
+            )
+            cls = Expected
+            gv: R.Tuple(R.Object, R.Object) = R.call_builtin_with_ctx(
+                "vm.builtin.cuda_graph.get_cached_alloc",
+                (cls.cuda_graph_alloc, R.prim_value(0)),
+                sinfo_args=(R.Tuple(R.Object, R.Object),),
+            )
+            storage: R.Object = gv[0]
+            alloc1: R.Tensor((m,), dtype="float32") = R.memory.alloc_tensor(
+                storage, R.prim_value(0), R.shape([m]), R.dtype("float32")
+            )
+            cls.add_one(x, alloc1)
+            storage1: R.Object = gv[1]
+            alloc2: R.Tensor((m,), dtype="float32") = R.memory.alloc_tensor(
+                storage1, R.prim_value(0), R.shape([m]), R.dtype("float32")
+            )
+            R.call_builtin_with_ctx(
+                "vm.builtin.cuda_graph.run_or_capture",
+                (
+                    cls.cuda_graph_capture,
+                    (alloc1, alloc2, R.shape([m])),
+                    R.prim_value(0),
+                    R.shape([m]),
+                ),
+                sinfo_args=(R.Tuple,),
+            )
+            alloc3: R.Tensor((m,), dtype="float32") = R.builtin.alloc_tensor(
+                R.shape([m]), R.dtype("float32"), R.prim_value(0), R.str("global")
+            )
+            cls.add_one(alloc2, alloc3)
+            return alloc3
+
+    mod = relax.transform.RewriteCUDAGraph()(Before)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to