This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new bb1e10c81f [Unity] Add rewriting for CUDA graph capturing (#14513)
bb1e10c81f is described below
commit bb1e10c81f2eb9e5e67c3ecc6310c60d247ea283
Author: Wuwei Lin <[email protected]>
AuthorDate: Wed Apr 19 11:56:24 2023 -0700
[Unity] Add rewriting for CUDA graph capturing (#14513)
* [Unity] Add rewriting for CUDA graph capturing
* address comments
* Update python/tvm/relax/transform/transform.py
Co-authored-by: Siyuan Feng <[email protected]>
* address comments
---------
Co-authored-by: Siyuan Feng <[email protected]>
---
include/tvm/relax/transform.h | 7 +
python/tvm/relax/transform/transform.py | 12 +
python/tvm/relax/vm_build.py | 4 +
src/relax/transform/rewrite_cuda_graph.cc | 512 +++++++++++++++++++++
src/support/ordered_set.h | 68 +++
.../relax/test_transform_rewrite_cuda_graph.py | 228 +++++++++
6 files changed, 831 insertions(+)
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 9a9d1eb54e..27bd1bd702 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -483,6 +483,13 @@ TVM_DLL Pass DeadCodeElimination(Array<runtime::String>
entry_functions);
*/
TVM_DLL Pass ToMixedPrecision(const DataType& out_dtype);
+/*!
+ * \brief Rewrite a Relax module for executing with CUDA graph. This pass
identifies
+ * the regions that can be executed with CUDA graph and lifts them into new
functions for runtime
+ * graph capturing.
+ */
+TVM_DLL Pass RewriteCUDAGraph();
+
} // namespace transform
} // namespace relax
} // namespace tvm
diff --git a/python/tvm/relax/transform/transform.py
b/python/tvm/relax/transform/transform.py
index b17f2fe62b..870b731883 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -964,6 +964,18 @@ def CombineParallelMatmul():
return _ffi_api.CombineParallelMatmul() # type: ignore
def RewriteCUDAGraph() -> tvm.ir.transform.Pass:
    """Rewrite a Relax module for executing with CUDA graph.

    The pass locates the statically-executable regions of each function, and
    lifts every such region into a new function that the runtime can capture
    once and replay as a CUDA graph.

    Returns
    -------
    ret: tvm.ir.transform.Pass
        The registered pass for rewriting cuda graph
    """
    return _ffi_api.RewriteCUDAGraph()  # type: ignore
+
+
def _wrap_class_function_pass(pass_cls, pass_info):
"""Wrap a python class as function pass."""
diff --git a/python/tvm/relax/vm_build.py b/python/tvm/relax/vm_build.py
index 0586bf9217..c89c64cd81 100644
--- a/python/tvm/relax/vm_build.py
+++ b/python/tvm/relax/vm_build.py
@@ -297,6 +297,10 @@ def build(
passes.append(relax.transform.ToNonDataflow())
passes.append(relax.transform.CallTIRRewrite())
passes.append(relax.transform.StaticPlanBlockMemory())
+
+ if
tvm.transform.PassContext.current().config.get("relax.backend.use_cuda_graph",
False):
+ passes.append(relax.transform.RewriteCUDAGraph())
+
passes.append(relax.transform.VMBuiltinLower())
passes.append(relax.transform.VMShapeLower())
passes.append(relax.transform.AttachGlobalSymbol())
diff --git a/src/relax/transform/rewrite_cuda_graph.cc
b/src/relax/transform/rewrite_cuda_graph.cc
new file mode 100644
index 0000000000..9621d9ff58
--- /dev/null
+++ b/src/relax/transform/rewrite_cuda_graph.cc
@@ -0,0 +1,512 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file relax/transform/rewrite_cuda_graph.cc
+ * \brief Pass for transforming Relax module to execute with CUDA graph.
+ *
+ * CUDA graph provides a way to capture a sequence of CUDA kernel launches in
the runtime and
+ * save them as a graph. The graph can be executed multiple times with less
overhead than launching
+ * kernels individually. This pass rewrites the Relax module to execute with
CUDA graph.
+ *
+ * The transformation is done in two steps:
+ *
+ * 1. Identify the regions that can be captured by CUDA graph and create the
rewriting plan with
+ * `CUDAGraphRewritePlanner`.
+ *
+ * A region is a subgraph of the Relax function that is executed statically. A region is executed
+ * statically if 1) it only depends on the memory allocated internally in the Relax function with
+ * constant shapes, 2) it only contains kernel launches (there is no control flow).
+ *
+ * This transformation is expected to run after `StaticPlanBlockMemory`. After
+ * `StaticPlanBlockMemory`, all the tensors that can be statically allocated
are allocated with
+ * `R.memory.alloc_storage` and `R.memory.alloc_tensor`, while other tensors
will be allocated via
+ * `R.builtin.alloc_tensor`.
+ *
+ * `CUDAGraphRewritePlanner` is executed at the level of BindingBlock. It first identifies all the
+ * storage objects allocated with `R.memory.alloc_storage` within the BindingBlock, and then
+ * identifies the static regions by propagating starting from the storage objects.
+ *
+ * All the calls to `R.memory.alloc_storage` within the same BindingBlock are grouped into a single
+ * new function. Each of the static regions is lifted to a new function.
+ *
+ * 2. Lift the regions identified in step 1 to a separate function and rewrite
the original function
+ * with `CUDAGraphRewriter`.
+ */
+
+#include <tvm/relax/backend.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/tir/expr.h>
+
+#include "../../support/arena.h"
+#include "../../support/ordered_set.h"
+#include "../../support/utils.h"
+
+namespace tvm {
+namespace relax {
+
+TVM_REGISTER_PASS_CONFIG_OPTION("relax.backend.use_cuda_graph", Bool);
+
/*! \brief The rewriting plan of lifting a region for either allocation or capturing for cuda graph
 * execution
 */
struct LiftedFunctionRewritePlan {
  // The lifted function for allocation or capturing
  Function func;
  // Whether the lifted function is for allocation (true) or capturing (false)
  bool is_alloc;
  // The binding var before which the lifted function should be invoked
  const VarNode* launch_point;

  // Variable remappings between the original function and the lifted function

  // The bindings in the original function that are lifted (and thus removed from the original)
  std::unordered_set<const VarNode*> lifted_bindings;
  // The corresponding binding vars in the original function of the outputs of the lifted function
  std::vector<const VarNode*> outputs;
  // The corresponding binding vars in the original function of the inputs of the lifted function
  std::vector<const VarNode*> inputs;
};
+
/*! \brief Builder of the lifted function for cuda graph capturing or allocations */
class FuncBuilder : public ExprMutator {
 public:
  /*!
   * \brief Add a binding to the new function
   * \param binding The binding to add
   */
  void AddBinding(const VarBindingNode* binding) { bindings_.push_back(binding); }

  /*!
   * \brief Mark a variable as the input of the new function.
   * \param var The variable to mark as input
   */
  void MarkInput(const VarNode* var) { inputs_.push_back(var); }

  /*!
   * \brief Mark a variable as the output of the new function. The variable must be the LHS of an
   * existing binding in the new function.
   * \param var The variable to mark as output
   */
  void MarkOutput(const VarNode* var) { outputs_.push_back(var); }

  /*! \brief Get the number of bindings in the new function */
  auto size() const { return bindings_.size(); }

  /*! \brief Build the new function: the marked inputs become parameters, the collected bindings
   * become the body, and the marked outputs are returned as a tuple. */
  Function Build() {
    Array<Var> params;
    // Set up the parameters, remapping each marked input to a fresh parameter var
    for (const auto* input : inputs_) {
      auto new_var = Var(input->name_hint(), Downcast<Optional<StructInfo>>(input->struct_info_));
      var_remap_[input->vid] = new_var;
      params.push_back(new_var);
    }
    // Emit the function body
    builder_->BeginBindingBlock();
    for (const auto* binding : bindings_) {
      VisitBinding_(binding);
    }
    // Set up the outputs as a single tuple
    Array<Expr> outputs;
    for (const auto* var : outputs_) {
      outputs.push_back(VisitExpr_(var));
    }
    auto output = builder_->Emit(Tuple(outputs));
    auto block = builder_->EndBlock();
    auto body = builder_->Normalize(SeqExpr({block}, output));
    auto func = Function(params, body, Downcast<StructInfo>(output->struct_info_.value()));
    return func;
  }

  // Inputs/outputs are OrderedSets so repeated marking is idempotent and order is deterministic.
  support::OrderedSet<const VarNode*> inputs_;
  support::OrderedSet<const VarNode*> outputs_;
  std::vector<const VarBindingNode*> bindings_;
};
+
+/*!
+ * \brief The planner for rewriting the function to enable cuda graph
capturing.
+ */
+class CUDAGraphRewritePlanner : public ExprVisitor {
+ public:
+ explicit CUDAGraphRewritePlanner(const IRModule& mod) : mod_(mod) {}
+ std::vector<LiftedFunctionRewritePlan> Plan() {
+ for (const auto& pair : mod_->functions) {
+ const auto& func = pair.second;
+ if (func->IsInstance<FunctionNode>()) {
+ VisitExpr(func);
+ }
+ }
+ std::vector<LiftedFunctionRewritePlan> plans;
+
+ auto region_to_plan = [&](FuncBuilder* region, bool is_alloc) ->
LiftedFunctionRewritePlan {
+ LiftedFunctionRewritePlan plan;
+ plan.is_alloc = true;
+ plan.func = region->Build();
+ ICHECK(region->size());
+ plan.launch_point = region->bindings_.front()->var.get();
+ plan.is_alloc = is_alloc;
+ for (const auto* binding : region->bindings_) {
+ plan.lifted_bindings.insert(binding->var.get());
+ }
+ plan.inputs.assign(region->inputs_.begin(), region->inputs_.end());
+ plan.outputs.assign(region->outputs_.begin(), region->outputs_.end());
+ return plan;
+ };
+
+ for (auto* region : alloc_storages_) {
+ plans.push_back(region_to_plan(region, true));
+ }
+
+ for (auto* region : captured_regions_) {
+ plans.push_back(region_to_plan(region, false));
+ }
+ return plans;
+ }
+
+ /*!
+ *\brief Start a new static region. This method should be called when
encountering a
+ * CUDA kernel launch (calls to PrimFunc or ExternFunc) that only depends on
static parameters.
+ */
+ void StartRegion() { current_.capture_builder = arena_.make<FuncBuilder>(); }
+
+ /*!
+ * \brief Finish a static region. This method should be called when
non-static bindings or
+ * unsupported operations are encountered.
+ */
+ void EndRegion() {
+ if (current_.capture_builder && current_.capture_builder->size()) {
+ captured_regions_.emplace_back(current_.capture_builder);
+ }
+ current_.capture_builder = nullptr;
+ }
+
+ void VisitBindingBlock_(const BindingBlockNode* binding_block) final {
+ Scope new_scope;
+ std::swap(new_scope, current_);
+ current_.alloc_storage_builder = arena_.make<FuncBuilder>();
+ for (const auto& binding : binding_block->bindings) {
+ VisitBinding(binding);
+ }
+ EndRegion();
+ if (current_.alloc_storage_builder->outputs_.size()) {
+ alloc_storages_.emplace_back(current_.alloc_storage_builder);
+ }
+ std::swap(new_scope, current_);
+ }
+
+ void VisitBinding_(const VarBindingNode* binding, const CallNode* call)
final {
+ static const auto& mem_alloc_storage_op =
Op::Get("relax.memory.alloc_storage");
+ static const auto& mem_kill_storage_op =
Op::Get("relax.memory.kill_storage");
+ static const auto& builtin_alloc_tensor_op =
Op::Get("relax.builtin.alloc_tensor");
+ static const auto& call_builtin_with_ctx_op =
Op::Get("relax.call_builtin_with_ctx");
+
+ if (call->op.same_as(mem_alloc_storage_op) &&
IsStaticAllocStorage(binding)) {
+ AddStaticBinding(binding, /*is_alloc_storage=*/true);
+ return;
+ } else if (call->op.same_as(mem_kill_storage_op) ||
call->op.same_as(builtin_alloc_tensor_op)) {
+ return;
+ }
+
+ const auto* call_gv = call->op.as<GlobalVarNode>();
+ bool call_prim_func =
+ call_gv ?
mod_->Lookup(GetRef<GlobalVar>(call_gv))->IsInstance<tir::PrimFuncNode>() :
false;
+ bool is_kernel_launch = [&]() {
+ if (call_prim_func) {
+ return true;
+ }
+ if (call->op.as<ExternFuncNode>()) {
+ return true;
+ }
+ if (const auto* op = call->op.as<OpNode>()) {
+ return !support::StartsWith(op->name, "relax.memory") &&
+ !support::StartsWith(op->name, "relax.builtin") &&
+ !GetRef<Op>(op).same_as(call_builtin_with_ctx_op);
+ }
+ return false;
+ }();
+
+ std::vector<const VarNode*> args;
+ bool is_all_static = IsStatic(call->args, &args);
+ if (call_gv != nullptr && !call_prim_func) {
+ // calls to other Relax functions are not allowed
+ is_all_static = false;
+ }
+ if (is_all_static) {
+ if (current_.capture_builder == nullptr && is_kernel_launch) {
+ StartRegion();
+ }
+ AddStaticBinding(binding, /*is_alloc_storage=*/false);
+ MarkAsFuncInput(args);
+ } else {
+ EndRegion();
+ }
+
+ MarkAsFuncOutput(args);
+ }
+
+ void MarkAsFuncInput(const std::vector<const VarNode*>& vars) {
+ if (current_.capture_builder == nullptr) {
+ return;
+ }
+ for (const VarNode* var : vars) {
+ auto it = binding_to_region_.find(var);
+ if (it == binding_to_region_.end() || it->second !=
current_.capture_builder) {
+ current_.capture_builder->MarkInput(var);
+ }
+ }
+ }
+
+ void MarkAsFuncOutput(const std::vector<const VarNode*>& vars) {
+ for (const VarNode* var : vars) {
+ if (auto it = binding_to_region_.find(var);
+ it != binding_to_region_.end() && it->second !=
current_.capture_builder) {
+ it->second->MarkOutput(var);
+ }
+ }
+ }
+
+ void VisitBinding_(const VarBindingNode* binding, const VarNode* var) final {
+ if (IsStatic(GetRef<Var>(var))) {
+ AddStaticBinding(binding, false);
+ MarkAsFuncInput({var});
+ } else {
+ EndRegion();
+ }
+ MarkAsFuncOutput({var});
+ }
+
+ void VisitBinding_(const VarBindingNode* binding, const ConstantNode*
constant) final {
+ AddStaticBinding(binding, false);
+ }
+
+ void VisitBinding_(const VarBindingNode* binding, const TupleNode* tuple)
final {
+ std::vector<const VarNode*> args;
+ if (IsStatic(tuple->fields, &args)) {
+ AddStaticBinding(binding, false);
+ MarkAsFuncInput(args);
+ } else {
+ EndRegion();
+ }
+ MarkAsFuncOutput(args);
+ }
+
+ void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode*
tuple_get_item) final {
+ const VarNode* tuple = tuple_get_item->tuple.as<VarNode>();
+ ICHECK(tuple);
+ if (IsStatic(tuple_get_item->tuple)) {
+ AddStaticBinding(binding, false);
+ MarkAsFuncInput({tuple});
+ } else {
+ EndRegion();
+ }
+ MarkAsFuncOutput({tuple});
+ }
+
+ bool IsStatic(const PrimExpr& expr,
+ [[maybe_unused]] std::vector<const VarNode*>* vars_collector =
nullptr) {
+ return expr->IsInstance<tir::IntImmNode>() ||
expr->IsInstance<tir::FloatImmNode>();
+ }
+
+ bool IsStatic(const Expr& expr, std::vector<const VarNode*>* vars_collector
= nullptr) {
+ if (expr->IsInstance<ConstantNode>() ||
expr->IsInstance<DataTypeImmNode>()) {
+ return true;
+ }
+ if (const auto* prim_value = expr.as<PrimValueNode>()) {
+ return IsStatic(prim_value->value, vars_collector);
+ }
+ if (const auto* var = expr.as<VarNode>()) {
+ if (vars_collector != nullptr) {
+ vars_collector->push_back(var);
+ }
+ return static_bindings_.count(var);
+ }
+
+ if (const auto* shape = expr.as<ShapeExprNode>()) {
+ return IsStatic(shape->values, vars_collector);
+ }
+ if (const auto* tuple = expr.as<TupleNode>()) {
+ return IsStatic(tuple->fields, vars_collector);
+ }
+ return false;
+ }
+
+ template <typename T>
+ bool IsStatic(const Array<T>& exprs, std::vector<const VarNode*>*
vars_collector = nullptr) {
+ return std::all_of(exprs.begin(), exprs.end(),
+ [&](const T& expr) { return IsStatic(expr,
vars_collector); });
+ }
+
+ private:
+ bool IsStaticAllocStorage(const VarBindingNode* binding) {
+ // Check if the allocation has constant shape
+ const auto* alloc_storage_call = binding->value.as<CallNode>();
+ auto shape = Downcast<ShapeExpr>(alloc_storage_call->args[0]);
+ return std::all_of(shape->values.begin(), shape->values.end(),
+ [](const PrimExpr& expr) { return expr.as<IntImmNode>()
!= nullptr; });
+ }
+
+ /*!
+ * \brief Add a static bindings. This is used to mark the bindings that are
known to be static
+ * and further propagate to other bindings.
+ * \param binding the binding to add
+ * \param is_alloc_storage whether the binding is call to
R.memory.alloc_storage
+ */
+ void AddStaticBinding(const VarBindingNode* binding, bool is_alloc_storage) {
+ if (is_alloc_storage) {
+ current_.alloc_storage_builder->AddBinding(binding);
+ binding_to_region_[binding->var.get()] = current_.alloc_storage_builder;
+ } else if (current_.capture_builder != nullptr) {
+ // Add the binding if the capture builder exists. It is possible that
capture builder is
+ // null when it is not capturing. This is the case that there are not
yet any kernel launches
+ // encountered, in this case static bindings (e.g. binding of other
non-kernel-launch
+ // operations) are marked but are not lifted.
+ current_.capture_builder->AddBinding(binding);
+ binding_to_region_[binding->var.get()] = current_.capture_builder;
+ }
+ static_bindings_.emplace(binding->var.get(), GetRef<VarBinding>(binding));
+ }
+
+ /*! \brief The states of the current scope (the BindingBlock) which is a
pair of FuncBuilder.
+ * The FuncBuilder are initialized with nullptr, meaning the planner is
currently not doing any
+ * lifting. They are initialized lazily when a binding that can be lifted is
encountered.
+ * They are reset to nullptr when an unsupported operation is encountered.
+ */
+ struct Scope {
+ FuncBuilder* alloc_storage_builder = nullptr; // The builder for the
allocation function
+ FuncBuilder* capture_builder = nullptr; // The builder for the
capture function
+ };
+
+ // The IRModule
+ IRModule mod_;
+ // States of the current scope
+ Scope current_;
+ // All the static bindings
+ std::unordered_map<const VarNode*, VarBinding> static_bindings_;
+ // Binding to the FuncBuilder if the binding is lifted. This is used to
update the inputs/outputs
+ // of the lifted function when its binding is used outside.
+ std::unordered_map<const VarNode*, FuncBuilder*> binding_to_region_;
+ // The regions for capturing.
+ std::vector<FuncBuilder*> captured_regions_;
+ // The regions for allocation.
+ std::vector<FuncBuilder*> alloc_storages_;
+ // The arena.
+ support::Arena arena_;
+};
+
/*! \brief The rewriter for CUDA graph */
class CUDAGraphRewriter : public ExprMutator {
 public:
  explicit CUDAGraphRewriter(const IRModule& mod) : ExprMutator(mod) {}

  /*! \brief Plan the liftable regions and rewrite every Relax function in the module. */
  IRModule Rewrite() {
    CUDAGraphRewritePlanner planner(builder_->GetContextIRModule());
    auto plans = planner.Plan();
    for (const auto& plan : plans) {
      subgraph_launches_[plan.launch_point] = plan;
    }

    for (const auto& [gv, func] : builder_->GetContextIRModule()->functions) {
      if (func->IsInstance<FunctionNode>()) {
        auto new_func = Downcast<Function>(VisitExpr(func));
        if (!new_func.same_as(func)) {
          builder_->UpdateFunction(gv, new_func);
        }
      }
    }
    return builder_->GetContextIRModule();
  }

  /*!
   * \brief Emit the runtime call that invokes a lifted function (cached allocation or graph
   * capture/replay) and record the redefinitions of its outputs.
   * \param op The binding at whose position the lifted function is launched
   * \param plan The rewriting plan of the lifted function
   */
  void LaunchSubgraph(const VarBindingNode* op, const LiftedFunctionRewritePlan& plan) {
    static const auto& call_builtin_with_ctx_op = Op::Get("relax.call_builtin_with_ctx");
    static const auto& builtin_run_or_capture = ExternFunc("vm.builtin.cuda_graph.run_or_capture");
    static const auto& builtin_get_cached_alloc =
        ExternFunc("vm.builtin.cuda_graph.get_cached_alloc");

    Expr launch_subgraph;
    auto gv_func =
        builder_->AddFunction(plan.func, plan.is_alloc ? "cuda_graph_alloc" : "cuda_graph_capture");
    if (plan.is_alloc) {
      // Allocation functions take no inputs; results come from the runtime cache.
      ICHECK(plan.inputs.empty());
      launch_subgraph =
          Call(call_builtin_with_ctx_op,
               {builtin_get_cached_alloc,
                Tuple({gv_func, PrimValue(IntImm(DataType::Int(64), index_alloc_++))})},
               Attrs(), {plan.func->ret_struct_info});
    } else {
      Array<Expr> args;
      for (const auto& arg : plan.inputs) {
        args.push_back(VisitExpr_(arg));
      }
      launch_subgraph = Call(
          call_builtin_with_ctx_op,
          {builtin_run_or_capture,
           Tuple({gv_func, Tuple(args), PrimValue(IntImm(DataType::Int(64), index_capture_++))})},
          Attrs(), {plan.func->ret_struct_info});
    }
    Expr ret_value = builder_->Emit(launch_subgraph);
    // Each original output var is redefined as an element of the returned tuple.
    for (int i = 0; i < static_cast<int>(plan.outputs.size()); ++i) {
      var_redef_[plan.outputs[i]] = TupleGetItem(ret_value, i);
    }

    lifted_bindings_.insert(plan.lifted_bindings.begin(), plan.lifted_bindings.end());
  }

  void VisitBinding_(const VarBindingNode* op) final {
    if (subgraph_launches_.count(op->var.get())) {
      LaunchSubgraph(op, subgraph_launches_[op->var.get()]);
    }
    if (auto it = var_redef_.find(op->var.get()); it != var_redef_.end()) {
      // The var is an output of a lifted function: rebind it to the tuple element.
      auto new_var = builder_->Emit(it->second, op->var->name_hint());
      var_remap_[op->var->vid] = new_var;
      return;
    }
    if (lifted_bindings_.count(op->var.get())) {
      // The binding is lifted to the subgraph and will be removed from the original function.
      return;
    }
    ExprMutator::VisitBinding_(op);
  }

  // Launch points (first lifted binding var of a region) mapped to their plans.
  std::unordered_map<const VarNode*, LiftedFunctionRewritePlan> subgraph_launches_;
  // Vars whose definitions are replaced by elements of a lifted function's result tuple.
  std::unordered_map<const VarNode*, Expr> var_redef_;
  // All binding vars that were moved into lifted functions.
  std::unordered_set<const VarNode*> lifted_bindings_;
  // Running counters used as cache indices by the runtime builtins.
  int index_alloc_ = 0;
  int index_capture_ = 0;
};
+
/*! \brief Entry point: rewrite the module so static regions run via CUDA graph capture/replay. */
IRModule RewriteCUDAGraph(IRModule mod) {
  CUDAGraphRewriter rewriter(mod);
  mod = rewriter.Rewrite();
  return mod;
}
+
+namespace transform {
+
/*! \brief Create the module pass wrapping ::tvm::relax::RewriteCUDAGraph. */
Pass RewriteCUDAGraph() {
  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =  //
      [=](IRModule m, PassContext pc) { return ::tvm::relax::RewriteCUDAGraph(std::move(m)); };
  return CreateModulePass(pass_func, 0, "RewriteCUDAGraph", {});
}

// Expose the pass to the Python FFI as relax.transform.RewriteCUDAGraph.
TVM_REGISTER_GLOBAL("relax.transform.RewriteCUDAGraph").set_body_typed(RewriteCUDAGraph);
+
+} // namespace transform
+
+} // namespace relax
+} // namespace tvm
diff --git a/src/support/ordered_set.h b/src/support/ordered_set.h
new file mode 100644
index 0000000000..8ba7708961
--- /dev/null
+++ b/src/support/ordered_set.h
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file support/ordered_set.h
+ * \brief An STL-like ordered set implementation.
+ */
+#ifndef TVM_SUPPORT_ORDERED_SET_H_
+#define TVM_SUPPORT_ORDERED_SET_H_
+
+#include <list>
+#include <unordered_map>
+
+namespace tvm {
+namespace support {
+
/*!
 * \brief A set that iterates its elements in first-insertion order.
 *
 * Membership is tracked by a hash map from element to its position in a
 * linked list, giving O(1) insert, erase, and duplicate detection while the
 * list preserves the order in which elements were first added.
 */
template <typename T>
class OrderedSet {
 public:
  /*! \brief Append \p t to the set; a duplicate insertion is a no-op. */
  void push_back(const T& t) {
    if (elem_to_iter_.find(t) != elem_to_iter_.end()) {
      return;
    }
    elements_.push_back(t);
    elem_to_iter_.emplace(t, std::prev(elements_.end()));
  }

  /*! \brief Remove \p t from the set if present; otherwise a no-op. */
  void erase(const T& t) {
    auto it = elem_to_iter_.find(t);
    if (it == elem_to_iter_.end()) {
      return;
    }
    elements_.erase(it->second);
    elem_to_iter_.erase(it);
  }

  /*! \brief Remove all elements. */
  void clear() {
    elements_.clear();
    elem_to_iter_.clear();
  }

  auto begin() const { return elements_.begin(); }
  auto end() const { return elements_.end(); }
  auto size() const { return elements_.size(); }
  auto empty() const { return elements_.empty(); }

 private:
  // Insertion-ordered storage; std::list iterators stay valid across erasures.
  std::list<T> elements_;
  // Element -> position in elements_, for O(1) lookup and removal.
  std::unordered_map<T, typename std::list<T>::iterator> elem_to_iter_;
};
+
+} // namespace support
+} // namespace tvm
+
+#endif // TVM_SUPPORT_ORDERED_SET_H_
diff --git a/tests/python/relax/test_transform_rewrite_cuda_graph.py
b/tests/python/relax/test_transform_rewrite_cuda_graph.py
new file mode 100644
index 0000000000..4fc4d6f4a1
--- /dev/null
+++ b/tests/python/relax/test_transform_rewrite_cuda_graph.py
@@ -0,0 +1,228 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+from tvm import relax
+from tvm.script import tir as T, relax as R, ir as I
+import tvm.testing
+
+
+def test_rewrite_cuda_graph():
+ # fmt: off
+ @I.ir_module
+ class Before:
+ @T.prim_func
+ def exp(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"),
compute: T.Buffer((T.int64(2), T.int64(4)), "float32")):
+ # function attr dict
+ T.func_attr({"tir.noalias": True, "global_symbol": "exp"})
+ # body
+ # with T.block("root")
+ for i0_i1_fused_0 in T.thread_binding(T.int64(1),
thread="blockIdx.x"):
+ for i0_i1_fused_1 in T.thread_binding(T.int64(8),
thread="threadIdx.x"):
+ with T.block("compute"):
+ i0 = T.axis.spatial(T.int64(2), (i0_i1_fused_0 *
T.int64(8) + i0_i1_fused_1) // T.int64(4))
+ i1 = T.axis.spatial(T.int64(4), (i0_i1_fused_0 *
T.int64(8) + i0_i1_fused_1) % T.int64(4))
+ T.reads(rxplaceholder[i0, i1])
+ T.writes(compute[i0, i1])
+ compute[i0, i1] = T.exp(rxplaceholder[i0, i1],
dtype="float32")
+
+
+ @R.function
+ def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((10,),
dtype="float32"):
+ cls = Before
+ storage: R.Object = R.memory.alloc_storage(R.shape([32]), 0,
"global", "float32")
+ alloc: R.Tensor((2, 4), dtype="float32") =
R.memory.alloc_tensor(storage, 0, R.shape([2, 4]), "float32")
+ _1: R.Tuple = cls.exp(x, alloc)
+ storage1: R.Object = R.memory.alloc_storage(R.shape([32]), 0,
"global", "float32")
+ alloc1: R.Tensor((2, 4), dtype="float32") =
R.memory.alloc_tensor(storage1, 0, R.shape([2, 4]), "float32")
+ _2: R.Tuple = cls.exp(alloc, alloc1)
+ _3: R.Tuple = R.memory.kill_tensor(alloc)
+ alloc2: R.Tensor((2, 4), dtype="float32") =
R.memory.alloc_tensor(storage, 0, R.shape([2, 4]), "float32")
+ _4: R.Tuple = cls.exp(alloc1, alloc2)
+ _5: R.Tuple = R.memory.kill_tensor(alloc1)
+ alloc3: R.Tensor((2, 4), dtype="float32") =
R.builtin.alloc_tensor(R.shape([2, 4]), "float32", 0)
+ _6 = cls.exp(alloc2, alloc3)
+ _7: R.Tuple = R.memory.kill_tensor(alloc2)
+ _8: R.Tuple = R.memory.kill_storage(storage)
+ _9: R.Tuple = R.memory.kill_storage(storage1)
+ return alloc3
+
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def exp(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"),
compute: T.Buffer((T.int64(2), T.int64(4)), "float32")):
+ # function attr dict
+ T.func_attr({"tir.noalias": True, "global_symbol": "exp"})
+ # body
+ # with T.block("root")
+ for i0_i1_fused_0 in T.thread_binding(T.int64(1),
thread="blockIdx.x"):
+ for i0_i1_fused_1 in T.thread_binding(T.int64(8),
thread="threadIdx.x"):
+ with T.block("compute"):
+ i0 = T.axis.spatial(T.int64(2), (i0_i1_fused_0 *
T.int64(8) + i0_i1_fused_1) // T.int64(4))
+ i1 = T.axis.spatial(T.int64(4), (i0_i1_fused_0 *
T.int64(8) + i0_i1_fused_1) % T.int64(4))
+ T.reads(rxplaceholder[i0, i1])
+ T.writes(compute[i0, i1])
+ compute[i0, i1] = T.exp(rxplaceholder[i0, i1],
dtype="float32")
+
+ @R.function
+ def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object):
+ gv: R.Object = R.memory.alloc_storage(R.shape([32]),
R.prim_value(0), R.str("global"), R.dtype("float32"))
+ gv1: R.Object = R.memory.alloc_storage(R.shape([32]),
R.prim_value(0), R.str("global"), R.dtype("float32"))
+ gv2: R.Tuple(R.Object, R.Object) = (gv, gv1)
+ return gv2
+
+ @R.function
+ def cuda_graph_capture(alloc: R.Tensor((2, 4), dtype="float32"),
alloc1: R.Tensor((2, 4), dtype="float32"), storage: R.Object) ->
R.Tuple(R.Tensor((2, 4), dtype="float32")):
+ cls = Expected
+ _2: R.Tuple = cls.exp(alloc, alloc1)
+ _3: R.Tuple = R.memory.kill_tensor(alloc)
+ alloc2: R.Tensor((2, 4), dtype="float32") =
R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([2, 4]),
R.dtype("float32"))
+ _4: R.Tuple = cls.exp(alloc1, alloc2)
+ _5: R.Tuple = R.memory.kill_tensor(alloc1)
+ gv: R.Tuple(R.Tensor((2, 4), dtype="float32")) = (alloc2,)
+ return gv
+
+ @R.function
+ def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((10,),
dtype="float32"):
+ cls = Expected
+ gv: R.Tuple(R.Object, R.Object) =
R.call_builtin_with_ctx("vm.builtin.cuda_graph.get_cached_alloc",
(cls.cuda_graph_alloc, R.prim_value(0)), sinfo_args=(R.Tuple(R.Object,
R.Object),))
+ storage: R.Object = gv[0]
+ alloc: R.Tensor((2, 4), dtype="float32") =
R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([2, 4]),
R.dtype("float32"))
+ _1: R.Tuple = cls.exp(x, alloc)
+ storage1: R.Object = gv[1]
+ alloc1: R.Tensor((2, 4), dtype="float32") =
R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([2, 4]),
R.dtype("float32"))
+ gv1: R.Tuple(R.Tensor((2, 4), dtype="float32")) =
R.call_builtin_with_ctx("vm.builtin.cuda_graph.run_or_capture",
(cls.cuda_graph_capture, (alloc, alloc1, storage), R.prim_value(0)),
sinfo_args=(R.Tuple(R.Tensor((2, 4), dtype="float32")),))
+ alloc2: R.Tensor((2, 4), dtype="float32") = gv1[0]
+ alloc3: R.Tensor((2, 4), dtype="float32") =
R.builtin.alloc_tensor(R.shape([2, 4]), R.dtype("float32"), R.prim_value(0))
+ _6: R.Tuple = cls.exp(alloc2, alloc3)
+ _7: R.Tuple = R.memory.kill_tensor(alloc2)
+ _8: R.Tuple = R.memory.kill_storage(storage)
+ _9: R.Tuple = R.memory.kill_storage(storage1)
+ return alloc3
+ # fmt: on
+
+ after = relax.transform.RewriteCUDAGraph()(Before)
+ tvm.ir.assert_structural_equal(after, Expected)
+
+
+def test_tuple():
+ # fmt: off
+ @I.ir_module
+ class Before:
+ @T.prim_func
+ def exp(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"),
compute: T.Buffer((T.int64(2), T.int64(4)), "float32")):
+ # function attr dict
+ T.func_attr({"tir.noalias": True, "global_symbol": "exp"})
+ # body
+ # with T.block("root")
+ for i0_i1_fused_0 in T.thread_binding(T.int64(1),
thread="blockIdx.x"):
+ for i0_i1_fused_1 in T.thread_binding(T.int64(8),
thread="threadIdx.x"):
+ with T.block("compute"):
+ i0 = T.axis.spatial(T.int64(2), (i0_i1_fused_0 *
T.int64(8) + i0_i1_fused_1) // T.int64(4))
+ i1 = T.axis.spatial(T.int64(4), (i0_i1_fused_0 *
T.int64(8) + i0_i1_fused_1) % T.int64(4))
+ T.reads(rxplaceholder[i0, i1])
+ T.writes(compute[i0, i1])
+ compute[i0, i1] = T.exp(rxplaceholder[i0, i1],
dtype="float32")
+
+
+ @R.function
+ def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((2, 4),
dtype="float32"):
+ cls = Before
+ storage: R.Object = R.memory.alloc_storage(R.shape([32]), 0,
"global", "float32")
+ alloc: R.Tensor((2, 4), dtype="float32") =
R.memory.alloc_tensor(storage, 0, R.shape([2, 4]), "float32")
+ _: R.Tuple = cls.exp(x, alloc)
+ storage1: R.Object = R.memory.alloc_storage(R.shape([32]), 0,
"global", "float32")
+ alloc1: R.Tensor((2, 4), dtype="float32") =
R.memory.alloc_tensor(storage1, 0, R.shape([2, 4]), "float32")
+ _: R.Tuple = cls.exp(alloc, alloc1)
+ lv0 = (alloc1,)
+ lv1 = (lv0,)
+ lv2 = lv1[0]
+ lv3 = lv2[0]
+ alloc2: R.Tensor((2, 4), dtype="float32") =
R.memory.alloc_tensor(storage, 0, R.shape([2, 4]), "float32")
+ _1: R.Tuple = cls.exp(lv3, alloc2)
+ _2: R.Tuple = R.memory.kill_tensor(alloc)
+ _3: R.Tuple = R.memory.kill_tensor(alloc1)
+ alloc3: R.Tensor((2, 4), dtype="float32") =
R.builtin.alloc_tensor(R.shape([2, 4]), R.dtype("float32"), R.prim_value(0))
+ _4: R.Tuple = cls.exp(alloc2, alloc3)
+ _5: R.Tuple = R.memory.kill_tensor(alloc2)
+ _6: R.Tuple = R.memory.kill_storage(storage)
+ _7: R.Tuple = R.memory.kill_storage(storage1)
+ return alloc2
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def exp(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"),
compute: T.Buffer((T.int64(2), T.int64(4)), "float32")):
+ T.func_attr({"global_symbol": "exp", "tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ for i0_i1_fused_0 in T.thread_binding(T.int64(1),
thread="blockIdx.x"):
+ for i0_i1_fused_1 in T.thread_binding(T.int64(8),
thread="threadIdx.x"):
+ with T.block("compute"):
+ i0 = T.axis.spatial(T.int64(2), (i0_i1_fused_0 *
T.int64(8) + i0_i1_fused_1) // T.int64(4))
+ i1 = T.axis.spatial(T.int64(4), (i0_i1_fused_0 *
T.int64(8) + i0_i1_fused_1) % T.int64(4))
+ T.reads(rxplaceholder[i0, i1])
+ T.writes(compute[i0, i1])
+ compute[i0, i1] = T.exp(rxplaceholder[i0, i1])
+
+ @R.function
+ def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object):
+ storage: R.Object = R.memory.alloc_storage(R.shape([32]),
R.prim_value(0), R.str("global"), R.dtype("float32"))
+ storage1: R.Object = R.memory.alloc_storage(R.shape([32]),
R.prim_value(0), R.str("global"), R.dtype("float32"))
+ gv: R.Tuple(R.Object, R.Object) = (storage, storage1)
+ return gv
+
+ @R.function
+ def cuda_graph_capture(alloc: R.Tensor((2, 4), dtype="float32"),
alloc1: R.Tensor((2, 4), dtype="float32"), storage: R.Object) ->
R.Tuple(R.Tensor((2, 4), dtype="float32")):
+ cls = Expected
+ _: R.Tuple = cls.exp(alloc, alloc1)
+ lv0: R.Tuple(R.Tensor((2, 4), dtype="float32")) = (alloc1,)
+ lv1: R.Tuple(R.Tuple(R.Tensor((2, 4), dtype="float32"))) = (lv0,)
+ lv2: R.Tuple(R.Tensor((2, 4), dtype="float32")) = lv1[0]
+ lv3: R.Tensor((2, 4), dtype="float32") = lv2[0]
+ alloc2: R.Tensor((2, 4), dtype="float32") =
R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([2, 4]),
R.dtype("float32"))
+ _1: R.Tuple = cls.exp(lv3, alloc2)
+ _2: R.Tuple = R.memory.kill_tensor(alloc)
+ _3: R.Tuple = R.memory.kill_tensor(alloc1)
+ gv: R.Tuple(R.Tensor((2, 4), dtype="float32")) = (alloc2,)
+ return gv
+
+ @R.function
+ def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((2, 4),
dtype="float32"):
+ cls = Expected
+ gv: R.Tuple(R.Object, R.Object) =
R.call_builtin_with_ctx("vm.builtin.cuda_graph.get_cached_alloc",
(cls.cuda_graph_alloc, R.prim_value(0)), sinfo_args=(R.Tuple(R.Object,
R.Object),))
+ storage: R.Object = gv[0]
+ alloc: R.Tensor((2, 4), dtype="float32") =
R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([2, 4]),
R.dtype("float32"))
+ _: R.Tuple = cls.exp(x, alloc)
+ storage1: R.Object = gv[1]
+ alloc1: R.Tensor((2, 4), dtype="float32") =
R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([2, 4]),
R.dtype("float32"))
+ gv1: R.Tuple(R.Tensor((2, 4), dtype="float32")) =
R.call_builtin_with_ctx("vm.builtin.cuda_graph.run_or_capture",
(cls.cuda_graph_capture, (alloc, alloc1, storage), R.prim_value(0)),
sinfo_args=(R.Tuple(R.Tensor((2, 4), dtype="float32")),))
+ alloc2: R.Tensor((2, 4), dtype="float32") = gv1[0]
+ alloc3: R.Tensor((2, 4), dtype="float32") =
R.builtin.alloc_tensor(R.shape([2, 4]), R.dtype("float32"), R.prim_value(0))
+ _4: R.Tuple = cls.exp(alloc2, alloc3)
+ _5: R.Tuple = R.memory.kill_tensor(alloc2)
+ _6: R.Tuple = R.memory.kill_storage(storage)
+ _7: R.Tuple = R.memory.kill_storage(storage1)
+ return alloc2
+ # fmt: on
+
+ after = relax.transform.RewriteCUDAGraph()(Before)
+ tvm.ir.assert_structural_equal(after, Expected)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()