This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ef32a611e3 [Relax] Enable capturing symbolic shapes in cuda graph
(#16815)
ef32a611e3 is described below
commit ef32a611e386251c86fa255db4b8530b291dde11
Author: Wuwei Lin <[email protected]>
AuthorDate: Sat Mar 30 14:31:22 2024 -0700
[Relax] Enable capturing symbolic shapes in cuda graph (#16815)
* [Relax] Enable capturing symbolic shapes in cuda graph
* Add Bind sinfo util
* Bind ret sinfo
* address comments
* add comments
* fix
* update test
---
include/tvm/relax/utils.h | 7 +
src/relax/transform/rewrite_cuda_graph.cc | 161 ++++++++++++++++++---
src/relax/utils.cc | 4 +
src/runtime/relax_vm/cuda/cuda_graph_builtin.cc | 62 ++++++--
.../relax/test_transform_rewrite_cuda_graph.py | 118 +++++++++++++++
5 files changed, 321 insertions(+), 31 deletions(-)
diff --git a/include/tvm/relax/utils.h b/include/tvm/relax/utils.h
index 74e773abe7..e48c1856f9 100644
--- a/include/tvm/relax/utils.h
+++ b/include/tvm/relax/utils.h
@@ -50,6 +50,13 @@ namespace relax {
TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds,
const tvm::Map<tir::Var, PrimExpr>& symbolic_var_map = {});
+/*!
+ * \brief Bind the symbolic variables to a StructInfo. This is a helper
function usually called by
+ * other pass functions to help optimizations.
+ */
+TVM_DLL StructInfo Bind(const StructInfo& sinfo,
+ const tvm::Map<tir::Var, PrimExpr>& symbolic_var_map);
+
/*!
* \brief Infer a binding map for symbolic variables
*
diff --git a/src/relax/transform/rewrite_cuda_graph.cc
b/src/relax/transform/rewrite_cuda_graph.cc
index b67a638dd6..25b229ebce 100644
--- a/src/relax/transform/rewrite_cuda_graph.cc
+++ b/src/relax/transform/rewrite_cuda_graph.cc
@@ -53,6 +53,8 @@
#include <tvm/relax/backend.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/stmt_functor.h>
#include "../../support/arena.h"
#include "../../support/ordered_set.h"
@@ -82,6 +84,8 @@ struct LiftedFunctionRewritePlan {
std::vector<const VarNode*> outputs;
// The corresponding binding vars in the original function of the inputs of
the lifted function
std::vector<const VarNode*> inputs;
+ // The tir vars in the original function that are propagated to the lifted
function
+ Optional<ShapeExpr> propogated_tir_vars = NullOpt;
};
/*! \brief Builder of the lifted function for cuda graph capturing or
allocations */
@@ -98,6 +102,11 @@ class FuncBuilder : public ExprMutator {
* \param var The variable to mark as input
*/
void MarkInput(const VarNode* var) { inputs_.push_back(var); }
+ /*!
+ * \brief Mark a TIR variable as the ShapeExpr input of the new function.
+ * \param var The variable to mark as input
+ */
+ void MarkShapeExprInput(const tir::VarNode* var) {
shape_expr_inputs_.push_back(var); }
/*!
* \brief Mark a variable as the output of the new function. The variable
must be the LHS of an
* existing binding in the new function.
@@ -111,12 +120,27 @@ class FuncBuilder : public ExprMutator {
/*! \brief Build the new function */
Function Build() {
Array<Var> params;
+ Optional<Var> shape_expr = NullOpt;
+ if (shape_expr_inputs_.size()) {
+ Array<PrimExpr> tir_vars;
+ for (const auto* var : shape_expr_inputs_) {
+ auto new_var = GetRef<tir::Var>(var).copy_with_suffix("");
+ tir_var_remap_.Set(GetRef<tir::Var>(var), new_var);
+ tir_vars.push_back(new_var);
+ }
+ shape_expr = Var("shape_expr", ShapeStructInfo(tir_vars));
+ }
// Set up the parameters
for (const auto* input : inputs_) {
- auto new_var = Var(input->name_hint(),
Downcast<Optional<StructInfo>>(input->struct_info_));
+ auto new_var = Var(
+ input->name_hint(),
+
VisitExprDepStructInfoField(Downcast<Optional<StructInfo>>(input->struct_info_).value()));
var_remap_[input->vid] = new_var;
params.push_back(new_var);
}
+ if (shape_expr) {
+ params.push_back(shape_expr.value());
+ }
// Emit the function body
builder_->BeginBindingBlock();
for (const auto* binding : bindings_) {
@@ -137,9 +161,13 @@ class FuncBuilder : public ExprMutator {
return func;
}
+ PrimExpr VisitPrimExpr(const PrimExpr& expr) { return tir::Substitute(expr,
tir_var_remap_); }
+
support::OrderedSet<const VarNode*> inputs_;
support::OrderedSet<const VarNode*> outputs_;
+ support::OrderedSet<const tir::VarNode*> shape_expr_inputs_;
std::vector<const VarBindingNode*> bindings_;
+ Map<tir::Var, PrimExpr> tir_var_remap_;
};
/*!
@@ -159,6 +187,7 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
static_vars_.insert(func->params[i].get());
}
}
+ CollectSymbolicVarHints(func);
VisitExpr(func);
}
}
@@ -174,6 +203,13 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
for (const auto* binding : region->bindings_) {
plan.lifted_bindings.insert(binding->var.get());
}
+ if (region->shape_expr_inputs_.size()) {
+ Array<PrimExpr> tir_vars;
+ for (const auto* var : region->shape_expr_inputs_) {
+ tir_vars.push_back(GetRef<PrimExpr>(var));
+ }
+ plan.propogated_tir_vars = ShapeExpr(tir_vars);
+ }
plan.inputs.assign(region->inputs_.begin(), region->inputs_.end());
plan.outputs.assign(region->outputs_.begin(), region->outputs_.end());
return plan;
@@ -189,6 +225,18 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
return plans;
}
+ /*!
+ * \brief Collect the name hints of the symbolic variables that are allowed
to be captured.
+ */
+ void CollectSymbolicVarHints(const Function& func) {
+ capture_symbolic_vars_.clear();
+ if (auto symbolic_vars =
+
func->attrs.GetAttr<Array<String>>("relax.rewrite_cuda_graph.capture_symbolic_vars"))
{
+ for (const auto& var : symbolic_vars.value()) {
+ capture_symbolic_vars_.insert(var);
+ }
+ }
+ }
/*!
*\brief Start a new static region. This method should be called when
encountering a
* CUDA kernel launch (calls to PrimFunc or ExternFunc) that only depends on
static parameters.
@@ -239,8 +287,9 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
// Check whether the call can be lifted to the capture function. It
requires all the arguments
// to be static and the call to be a kernel launch or a pure operation
(e.g. memory view).
std::vector<const VarNode*> args;
+ std::vector<const tir::VarNode*> tir_vars;
bool is_all_static = [&]() {
- if (!IsStatic(call->args, &args)) {
+ if (!IsStatic(call->args, &args, &tir_vars)) {
return false;
}
if (call_gv != nullptr && !call_prim_func) {
@@ -276,7 +325,7 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
StartRegion();
}
AddStaticBinding(binding, /*is_alloc_storage=*/false);
- MarkAsFuncInput(args);
+ MarkAsFuncInput(args, tir_vars);
} else {
EndRegion();
}
@@ -284,7 +333,8 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
MarkAsFuncOutput(args);
}
- void MarkAsFuncInput(const std::vector<const VarNode*>& vars) {
+ void MarkAsFuncInput(const std::vector<const VarNode*>& vars,
+ const std::vector<const tir::VarNode*>& tir_vars = {}) {
if (current_.capture_builder == nullptr) {
return;
}
@@ -294,6 +344,9 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
current_.capture_builder->MarkInput(var);
}
}
+ for (const tir::VarNode* tir_var : tir_vars) {
+ current_.capture_builder->MarkShapeExprInput(tir_var);
+ }
}
void MarkAsFuncOutput(const std::vector<const VarNode*>& vars) {
@@ -321,9 +374,10 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
void VisitBinding_(const VarBindingNode* binding, const TupleNode* tuple)
final {
std::vector<const VarNode*> args;
- if (IsStatic(tuple->fields, &args)) {
+ std::vector<const tir::VarNode*> tir_vars;
+ if (IsStatic(tuple->fields, &args, &tir_vars)) {
AddStaticBinding(binding, false);
- MarkAsFuncInput(args);
+ MarkAsFuncInput(args, tir_vars);
} else {
EndRegion();
}
@@ -343,48 +397,83 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
}
bool IsStatic(const PrimExpr& expr,
- [[maybe_unused]] std::vector<const VarNode*>* vars_collector =
nullptr) {
- return expr->IsInstance<tir::IntImmNode>() ||
expr->IsInstance<tir::FloatImmNode>();
+ [[maybe_unused]] std::vector<const VarNode*>* vars_collector =
nullptr,
+ std::vector<const tir::VarNode*>* tir_vars_collector =
nullptr) {
+ bool is_static = true;
+ tir::PostOrderVisit(expr, [&](const ObjectRef& e) {
+ if (auto var = e.as<tir::VarNode>()) {
+ if (!capture_symbolic_vars_.count(var->name_hint)) {
+ is_static = false;
+ return;
+ }
+ if (tir_vars_collector != nullptr) {
+ tir_vars_collector->push_back(var);
+ }
+ }
+ });
+ return is_static;
}
- bool IsStatic(const Expr& expr, std::vector<const VarNode*>* vars_collector
= nullptr) {
+ bool IsStatic(const Expr& expr, std::vector<const VarNode*>* vars_collector
= nullptr,
+ std::vector<const tir::VarNode*>* tir_vars_collector =
nullptr) {
if (expr->IsInstance<ConstantNode>() ||
expr->IsInstance<DataTypeImmNode>() ||
- expr->IsInstance<StringImmNode>()) {
+ expr->IsInstance<StringImmNode>() ||
expr->IsInstance<GlobalVarNode>()) {
return true;
}
if (const auto* prim_value = expr.as<PrimValueNode>()) {
- return IsStatic(prim_value->value, vars_collector);
+ return IsStatic(prim_value->value, vars_collector, tir_vars_collector);
}
if (const auto* var = expr.as<VarNode>()) {
if (vars_collector != nullptr) {
vars_collector->push_back(var);
}
- return static_vars_.count(var);
+ // recursively check the struct info to collect the symbolic TIR vars
+ return static_vars_.count(var) &&
IsStatic(Downcast<StructInfo>(var->struct_info_.value()),
+ vars_collector,
tir_vars_collector);
}
if (const auto* shape = expr.as<ShapeExprNode>()) {
- return IsStatic(shape->values, vars_collector);
+ return IsStatic(shape->values, vars_collector, tir_vars_collector);
}
if (const auto* tuple = expr.as<TupleNode>()) {
- return IsStatic(tuple->fields, vars_collector);
+ return IsStatic(tuple->fields, vars_collector, tir_vars_collector);
}
return false;
}
template <typename T>
- bool IsStatic(const Array<T>& exprs, std::vector<const VarNode*>*
vars_collector = nullptr) {
+ bool IsStatic(const Array<T>& exprs, std::vector<const VarNode*>*
vars_collector = nullptr,
+ std::vector<const tir::VarNode*>* tir_vars_collector =
nullptr) {
bool result = true;
for (const auto& expr : exprs) {
// If vars_collector is provided, we will collect all the vars in the
exprs and we should
// not perform short-circuiting.
- result &= IsStatic(expr, vars_collector);
- if (!vars_collector && !result) {
+ result &= IsStatic(expr, vars_collector, tir_vars_collector);
+ if (vars_collector == nullptr && tir_vars_collector == nullptr &&
!result) {
return false;
}
}
return result;
}
+ bool IsStatic(const StructInfo& sinfo, std::vector<const VarNode*>*
vars_collector = nullptr,
+ std::vector<const tir::VarNode*>* tir_vars_collector =
nullptr) {
+ if (const auto* tensor_sinfo = sinfo.as<TensorStructInfoNode>()) {
+ if (auto shape = tensor_sinfo->GetShape()) {
+ return IsStatic(shape.value(), vars_collector, tir_vars_collector);
+ }
+ } else if (const auto* shape_sinfo = sinfo.as<ShapeStructInfoNode>()) {
+ if (shape_sinfo->values) {
+ return IsStatic(shape_sinfo->values.value(), vars_collector,
tir_vars_collector);
+ }
+ } else if (const auto* tuple_sinfo = sinfo.as<TupleStructInfoNode>()) {
+ return IsStatic(tuple_sinfo->fields, vars_collector, tir_vars_collector);
+ } else if (sinfo.as<ObjectStructInfoNode>() ||
sinfo.as<PrimStructInfoNode>()) {
+ return true;
+ }
+ return false;
+ }
+
private:
bool IsStaticAllocStorage(const VarBindingNode* binding) {
// Check if the allocation has constant shape
@@ -431,6 +520,8 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
Scope current_;
// Variables whose buffer address is fixed
std::unordered_set<const VarNode*> static_vars_;
+ // The name hints of the symbolic variables that are allowed to be captured
+ std::unordered_set<String> capture_symbolic_vars_;
// Binding to the FuncBuilder if the binding is lifted. This is used to
update the inputs/outputs
// of the lifted function when its binding is used outside.
std::unordered_map<const VarNode*, FuncBuilder*> binding_to_region_;
@@ -475,6 +566,8 @@ class CUDAGraphRewriter : public ExprMutator {
auto gv_func =
builder_->AddFunction(plan.func, plan.is_alloc ? "cuda_graph_alloc" :
"cuda_graph_capture");
if (plan.is_alloc) {
+ // Storage allocation should be fully static and shouldn't depend on any
symbolic variables.
+ ICHECK(!plan.propogated_tir_vars.defined());
ICHECK(plan.inputs.empty());
launch_subgraph =
Call(call_builtin_with_ctx_op,
@@ -482,15 +575,39 @@ class CUDAGraphRewriter : public ExprMutator {
Tuple({gv_func, PrimValue(IntImm(DataType::Int(64),
index_alloc_++))})},
Attrs(), {plan.func->ret_struct_info});
} else {
+ StructInfo call_sinfo = plan.func->ret_struct_info;
+ // Arguments of the lifted function
Array<Expr> args;
for (const auto& arg : plan.inputs) {
args.push_back(VisitExpr_(arg));
}
- launch_subgraph = Call(
- call_builtin_with_ctx_op,
- {builtin_run_or_capture,
- Tuple({gv_func, Tuple(args), PrimValue(IntImm(DataType::Int(64),
index_capture_++))})},
- Attrs(), {plan.func->ret_struct_info});
+ if (plan.propogated_tir_vars.defined()) {
+ ShapeExpr propogated_tir_vars = plan.propogated_tir_vars.value();
+ args.push_back(propogated_tir_vars);
+ // The ret_struct_info of the lifted function can contain symbolic
variables. We need to
+ // bind the symbolic parameters to the actual values.
+ const auto& shape_expr = plan.func->params.back();
+ auto symbolic_params =
+
Downcast<ShapeStructInfo>(shape_expr->struct_info_.value())->values.value();
+ Map<tir::Var, PrimExpr> tir_var_remap;
+ ICHECK_EQ(symbolic_params.size(), propogated_tir_vars->values.size());
+ for (int i = 0; i < static_cast<int>(symbolic_params.size()); ++i) {
+ tir_var_remap.Set(Downcast<tir::Var>(symbolic_params[i]),
propogated_tir_vars->values[i]);
+ }
+ call_sinfo = Bind(call_sinfo, tir_var_remap);
+ }
+ // Arguments of builtin_run_or_capture
+ Array<Expr> tuple_arg_fields{gv_func, Tuple(args),
+ PrimValue(IntImm(DataType::Int(64),
index_capture_++))};
+ if (plan.propogated_tir_vars.defined()) {
+ // The shape expr is explicitly passed twice, once as the last argument
of the lifted
+ // function, once as the last argument of builtin_run_or_capture as the
cache key. Explicitly
+ // passing it twice simplifies the handling during the capture phase.
+ tuple_arg_fields.push_back(plan.propogated_tir_vars.value());
+ }
+ launch_subgraph =
+ Call(call_builtin_with_ctx_op, {builtin_run_or_capture,
Tuple(tuple_arg_fields)}, Attrs(),
+ {call_sinfo});
}
Expr ret_value = builder_->Emit(launch_subgraph);
for (int i = 0; i < static_cast<int>(plan.outputs.size()); ++i) {
diff --git a/src/relax/utils.cc b/src/relax/utils.cc
index a15ee79fac..77e6b33f0c 100644
--- a/src/relax/utils.cc
+++ b/src/relax/utils.cc
@@ -144,6 +144,10 @@ Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>&
binds,
return ExprBinder(binds, symbolic_var_map).VisitExpr(expr);
}
+StructInfo Bind(const StructInfo& sinfo, const tvm::Map<tir::Var, PrimExpr>&
symbolic_var_map) {
+ return ExprBinder({}, symbolic_var_map).VisitExprDepStructInfoField(sinfo);
+}
+
tvm::Map<tir::Var, PrimExpr> InferSymbolicVarMap(
const tvm::Map<relax::Var, relax::Expr>& relax_var_remap, arith::Analyzer*
analyzer) {
tvm::Map<tir::Var, PrimExpr> tir_var_remap;
diff --git a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
index f6eef9ca25..02b6da7dab 100644
--- a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
+++ b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
@@ -26,11 +26,45 @@
#include <tvm/runtime/registry.h>
#include <tvm/runtime/relax_vm/vm.h>
+#include "../../../support/utils.h"
#include "../../cuda/cuda_common.h"
namespace tvm {
namespace runtime {
namespace relax_vm {
+struct CUDAGraphCaptureKey {
+ // The unique index of the capture function within the module
+ int64_t index;
+ // The symbolic variables the capture function depends on. When the capture
function is run with
+ // different symbolic variable values, the CUDA graph will be re-captured as
a different version,
+ // identified by this shape tuple. This is default constructed as an empty
tuple.
+ ShapeTuple shape_expr;
+
+ CUDAGraphCaptureKey(int64_t index, const Optional<ShapeTuple>& shape_expr) :
index(index) {
+ if (shape_expr) {
+ this->shape_expr = shape_expr.value();
+ }
+ }
+};
+
+struct CUDAGraphCaptureKeyHash {
+ size_t operator()(const CUDAGraphCaptureKey& key) const {
+ std::hash<int64_t> hash_fn;
+ size_t hash = hash_fn(key.index);
+ for (const auto& shape : key.shape_expr) {
+ support::HashCombine(hash, hash_fn(shape));
+ }
+ return hash;
+ }
+};
+
+struct CUDAGraphCaptureKeyEqual {
+ bool operator()(const CUDAGraphCaptureKey& lhs, const CUDAGraphCaptureKey&
rhs) const {
+ return lhs.index == rhs.index && std::equal(lhs.shape_expr.begin(),
lhs.shape_expr.end(),
+ rhs.shape_expr.begin(),
rhs.shape_expr.end());
+ }
+};
+
/*! \brief The cache states of a CUDA graph. */
class CUDAGraphCache : public Object {
public:
@@ -62,8 +96,9 @@ class CUDAGraphCache : public Object {
* \return The return value of the capture function.
*/
ObjectRef RunOrCapture(VirtualMachine* vm, const ObjectRef& capture_func,
ObjectRef args,
- int64_t entry_index) {
- if (auto it = capture_cache_.find(entry_index); it !=
capture_cache_.end()) {
+ int64_t entry_index, Optional<ShapeTuple> shape_expr)
{
+ CUDAGraphCaptureKey entry_key{entry_index, shape_expr};
+ if (auto it = capture_cache_.find(entry_key); it != capture_cache_.end()) {
// Launch CUDA graph
const auto& [states, exec] = it->second;
CUDA_CALL(cudaGraphLaunch(exec, CUDAThreadEntry::ThreadLocal()->stream));
@@ -103,8 +138,8 @@ class CUDAGraphCache : public Object {
CUDA_CALL(cudaStreamEndCapture(CUDAThreadEntry::ThreadLocal()->stream,
&graph));
std::swap(capture_stream, CUDAThreadEntry::ThreadLocal()->stream);
- capture_cache_[entry_index] = entry;
- CUDA_CALL(cudaGraphInstantiate(&capture_cache_[entry_index].exec, graph,
NULL, NULL, 0));
+ capture_cache_[entry_key] = entry;
+ CUDA_CALL(cudaGraphInstantiate(&capture_cache_[entry_key].exec, graph,
NULL, NULL, 0));
CUDA_CALL(cudaStreamDestroy(capture_stream));
CUDA_CALL(cudaGraphDestroy(graph));
return entry.states;
@@ -134,7 +169,9 @@ class CUDAGraphCache : public Object {
* \brief The cache of captured cuda graphs. The key is a unique index for
the capture function.
* The value is the result of the capture.
*/
- std::unordered_map<int64_t, CaptureResult> capture_cache_;
+ std::unordered_map<CUDAGraphCaptureKey, CaptureResult,
CUDAGraphCaptureKeyHash,
+ CUDAGraphCaptureKeyEqual>
+ capture_cache_;
/*!
* \brief The cache of allocations. The key is a unique index for the
allocation function.
* The value is the cached allocations, which is a tuple of storages.
@@ -143,11 +180,18 @@ class CUDAGraphCache : public Object {
};
TVM_REGISTER_GLOBAL("vm.builtin.cuda_graph.run_or_capture")
- .set_body_typed([](TVMArgValue vm_ptr, ObjectRef capture_func, ObjectRef
func_args,
- int64_t entry_index) {
- VirtualMachine* vm = VirtualMachine::GetContextPtr(vm_ptr);
+ .set_body([](TVMArgs args, TVMRetValue* rv) {
+ ICHECK(args.size() == 5 || args.size() == 4);
+ VirtualMachine* vm = VirtualMachine::GetContextPtr(args[0]);
+ ObjectRef capture_func = args[1];
+ ObjectRef func_args = args[2];
+ int64_t entry_index = args[3];
+ Optional<ShapeTuple> shape_expr = NullOpt;
+ if (args.size() == 5) {
+ shape_expr = args[4].AsObjectRef<ShapeTuple>();
+ }
CUDAGraphCache* cache = CUDAGraphCache::Get();
- return cache->RunOrCapture(vm, capture_func, func_args, entry_index);
+ *rv = cache->RunOrCapture(vm, capture_func, func_args, entry_index,
shape_expr);
});
TVM_REGISTER_GLOBAL("vm.builtin.cuda_graph.get_cached_alloc")
diff --git a/tests/python/relax/test_transform_rewrite_cuda_graph.py
b/tests/python/relax/test_transform_rewrite_cuda_graph.py
index 91b3fce264..43b26f110f 100644
--- a/tests/python/relax/test_transform_rewrite_cuda_graph.py
+++ b/tests/python/relax/test_transform_rewrite_cuda_graph.py
@@ -757,5 +757,123 @@ def test_static_args():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_dynamic_capture():
+ @I.ir_module
+ class Before:
+ @T.prim_func
+ def add_one(x_handle: T.handle, y_handle: T.handle):
+ m = T.int64()
+ x = T.match_buffer(x_handle, (m,), "float32")
+ y = T.match_buffer(y_handle, (m,), "float32")
+ for i in range(m):
+ with T.block("add"):
+ vi = T.axis.remap("S", [i])
+ y[vi] = x[vi] + T.float32(1)
+
+ @R.function
+ def main(x: R.Tensor(("m",), "float32")) -> R.Tensor(("m",),
"float32"):
+ R.func_attr(
+ {"relax.rewrite_cuda_graph.capture_symbolic_vars": ["m"],
"relax.force_pure": True}
+ )
+ m = T.int64()
+ storage: R.Object = R.memory.alloc_storage(
+ R.shape([16]), 0, "global", "float32"
+ ) # assume m is upper-bounded
+ alloc1: R.Tensor((m,), "float32") = R.memory.alloc_tensor(
+ storage, 0, R.shape([m]), "float32"
+ )
+ _ = Before.add_one(x, alloc1)
+ storage1: R.Object = R.memory.alloc_storage(R.shape([16]), 0,
"global", "float32")
+ alloc2: R.Tensor((m,), "float32") = R.memory.alloc_tensor(
+ storage1, 0, R.shape([m]), "float32"
+ )
+ _ = Before.add_one(alloc1, alloc2)
+ alloc3: R.Tensor((m,), "float32") = R.builtin.alloc_tensor(
+ R.shape([m]), "float32", 0, "global"
+ )
+ _ = Before.add_one(alloc2, alloc3)
+ return alloc3
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def add_one(x_handle: T.handle, y_handle: T.handle):
+ m = T.int64()
+ x = T.match_buffer(x_handle, (m,))
+ y = T.match_buffer(y_handle, (m,))
+ # with T.block("root"):
+ for i in range(m):
+ with T.block("add"):
+ vi = T.axis.spatial(m, i)
+ T.reads(x[vi])
+ T.writes(y[vi])
+ y[vi] = x[vi] + T.float32(1)
+
+ @R.function(private=True)
+ def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object):
+ R.func_attr({"relax.force_pure": True})
+ storage: R.Object = R.memory.alloc_storage(
+ R.shape([16]), R.prim_value(0), R.str("global"),
R.dtype("float32")
+ )
+ storage1: R.Object = R.memory.alloc_storage(
+ R.shape([16]), R.prim_value(0), R.str("global"),
R.dtype("float32")
+ )
+ gv: R.Tuple(R.Object, R.Object) = storage, storage1
+ return gv
+
+ @R.function(private=True)
+ def cuda_graph_capture(
+ alloc1: R.Tensor(("m",), dtype="float32"),
+ alloc2: R.Tensor(("m",), dtype="float32"),
+ shape_expr: R.Shape(["m"]),
+ ):
+ m = T.int64()
+ R.func_attr({"relax.force_pure": True})
+ cls = Expected
+ cls.add_one(alloc1, alloc2)
+ gv = R.tuple()
+ return R.tuple()
+
+ @R.function
+ def main(x: R.Tensor(("m",), dtype="float32")) -> R.Tensor(("m",),
dtype="float32"):
+ m = T.int64()
+ R.func_attr(
+ {"relax.force_pure": True,
"relax.rewrite_cuda_graph.capture_symbolic_vars": ["m"]}
+ )
+ cls = Expected
+ gv: R.Tuple(R.Object, R.Object) = R.call_builtin_with_ctx(
+ "vm.builtin.cuda_graph.get_cached_alloc",
+ (cls.cuda_graph_alloc, R.prim_value(0)),
+ sinfo_args=(R.Tuple(R.Object, R.Object),),
+ )
+ storage: R.Object = gv[0]
+ alloc1: R.Tensor((m,), dtype="float32") = R.memory.alloc_tensor(
+ storage, R.prim_value(0), R.shape([m]), R.dtype("float32")
+ )
+ cls.add_one(x, alloc1)
+ storage1: R.Object = gv[1]
+ alloc2: R.Tensor((m,), dtype="float32") = R.memory.alloc_tensor(
+ storage1, R.prim_value(0), R.shape([m]), R.dtype("float32")
+ )
+ R.call_builtin_with_ctx(
+ "vm.builtin.cuda_graph.run_or_capture",
+ (
+ cls.cuda_graph_capture,
+ (alloc1, alloc2, R.shape([m])),
+ R.prim_value(0),
+ R.shape([m]),
+ ),
+ sinfo_args=(R.Tuple,),
+ )
+ alloc3: R.Tensor((m,), dtype="float32") = R.builtin.alloc_tensor(
+ R.shape([m]), R.dtype("float32"), R.prim_value(0),
R.str("global")
+ )
+ cls.add_one(alloc2, alloc3)
+ return alloc3
+
+ mod = relax.transform.RewriteCUDAGraph()(Before)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()