This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new bb1e10c81f [Unity] Add rewriting for CUDA graph capturing (#14513)
bb1e10c81f is described below

commit bb1e10c81f2eb9e5e67c3ecc6310c60d247ea283
Author: Wuwei Lin <[email protected]>
AuthorDate: Wed Apr 19 11:56:24 2023 -0700

    [Unity] Add rewriting for CUDA graph capturing (#14513)
    
    * [Unity] Add rewriting for CUDA graph capturing
    
    * address comments
    
    * Update python/tvm/relax/transform/transform.py
    
    Co-authored-by: Siyuan Feng <[email protected]>
    
    * address comments
    
    ---------
    
    Co-authored-by: Siyuan Feng <[email protected]>
---
 include/tvm/relax/transform.h                      |   7 +
 python/tvm/relax/transform/transform.py            |  12 +
 python/tvm/relax/vm_build.py                       |   4 +
 src/relax/transform/rewrite_cuda_graph.cc          | 512 +++++++++++++++++++++
 src/support/ordered_set.h                          |  68 +++
 .../relax/test_transform_rewrite_cuda_graph.py     | 228 +++++++++
 6 files changed, 831 insertions(+)

diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 9a9d1eb54e..27bd1bd702 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -483,6 +483,13 @@ TVM_DLL Pass DeadCodeElimination(Array<runtime::String> 
entry_functions);
  */
 TVM_DLL Pass ToMixedPrecision(const DataType& out_dtype);
 
+/*!
+ * \brief Rewrite a Relax module for executing with CUDA graph. This pass 
identifies
+ * the regions that can be executed with CUDA graph and lifts them into new 
functions for runtime
+ * graph capturing.
+ */
+TVM_DLL Pass RewriteCUDAGraph();
+
 }  // namespace transform
 }  // namespace relax
 }  // namespace tvm
diff --git a/python/tvm/relax/transform/transform.py 
b/python/tvm/relax/transform/transform.py
index b17f2fe62b..870b731883 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -964,6 +964,18 @@ def CombineParallelMatmul():
     return _ffi_api.CombineParallelMatmul()  # type: ignore
 
 
+def RewriteCUDAGraph() -> tvm.ir.transform.Pass:
+    """Rewrite a Relax module for executing with CUDA graph. This pass 
identifies the regions that
+    can be executed with CUDA graph and lifts them into new functions for 
runtime graph capturing.
+
+    Returns
+    -------
+    ret: tvm.ir.transform.Pass
+        The registered pass for rewriting cuda graph
+    """
+    return _ffi_api.RewriteCUDAGraph()  # type: ignore
+
+
 def _wrap_class_function_pass(pass_cls, pass_info):
     """Wrap a python class as function pass."""
 
diff --git a/python/tvm/relax/vm_build.py b/python/tvm/relax/vm_build.py
index 0586bf9217..c89c64cd81 100644
--- a/python/tvm/relax/vm_build.py
+++ b/python/tvm/relax/vm_build.py
@@ -297,6 +297,10 @@ def build(
     passes.append(relax.transform.ToNonDataflow())
     passes.append(relax.transform.CallTIRRewrite())
     passes.append(relax.transform.StaticPlanBlockMemory())
+
+    if 
tvm.transform.PassContext.current().config.get("relax.backend.use_cuda_graph", 
False):
+        passes.append(relax.transform.RewriteCUDAGraph())
+
     passes.append(relax.transform.VMBuiltinLower())
     passes.append(relax.transform.VMShapeLower())
     passes.append(relax.transform.AttachGlobalSymbol())
diff --git a/src/relax/transform/rewrite_cuda_graph.cc 
b/src/relax/transform/rewrite_cuda_graph.cc
new file mode 100644
index 0000000000..9621d9ff58
--- /dev/null
+++ b/src/relax/transform/rewrite_cuda_graph.cc
@@ -0,0 +1,512 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file relax/transform/rewrite_cuda_graph.cc
+ * \brief Pass for transforming Relax module to execute with CUDA graph.
+ *
+ * CUDA graph provides a way to capture a sequence of CUDA kernel launches in 
the runtime and
+ * save them as a graph. The graph can be executed multiple times with less 
overhead than launching
+ * kernels individually. This pass rewrites the Relax module to execute with 
CUDA graph.
+ *
+ * The transformation is done in two steps:
+ *
+ * 1. Identify the regions that can be captured by CUDA graph and create the 
rewriting plan with
+ * `CUDAGraphRewritePlanner`.
+ *
+ * A region is a subgraph of the Relax function that is executed statically. 
A region is executed
+ * statically if 1) it only depends on the memory allocated internally in the 
Relax function with
+ * constant shapes, and 2) it only contains kernel launches (there is no control 
flow).
+ *
+ * This transformation is expected to run after `StaticPlanBlockMemory`. After
+ * `StaticPlanBlockMemory`, all the tensors that can be statically allocated 
are allocated with
+ * `R.memory.alloc_storage` and `R.memory.alloc_tensor`, while other tensors 
will be allocated via
+ * `R.builtin.alloc_tensor`.
+ *
+ * `CUDAGraphRewritePlanner` is executed at the level of BindingBlock. It 
first identifies all the
+ * storage objects allocated with `R.memory.alloc_storage` within the 
BindingBlock, and then
+ * identifies the static regions by propagation, starting from the storage 
objects.
+ *
+ * All the calls to `R.memory.alloc_storage` within the same BindingBlock are 
grouped into a single
+ * new function. Each of the static regions is lifted to its own new function.
+ *
+ * 2. Lift the regions identified in step 1 to a separate function and rewrite 
the original function
+ * with `CUDAGraphRewriter`.
+ */
+
+#include <tvm/relax/backend.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/tir/expr.h>
+
+#include "../../support/arena.h"
+#include "../../support/ordered_set.h"
+#include "../../support/utils.h"
+
+namespace tvm {
+namespace relax {
+
+TVM_REGISTER_PASS_CONFIG_OPTION("relax.backend.use_cuda_graph", Bool);
+
+/*! \brief The rewriting plan of lifting a region for either allocation or 
capturing for cuda graph
+ * execution
+ */
+struct LiftedFunctionRewritePlan {
+  // The lifted function for allocation or capturing
+  Function func;
+  // Whether the lifted function is for allocation or capturing
+  bool is_alloc;
+  // The binding var before which the lifted function should be invoked
+  const VarNode* launch_point;
+
+  // Variable remappings between the original function and the lifted function
+
+  // The bindings in the original function that are lifted
+  std::unordered_set<const VarNode*> lifted_bindings;
+  // The corresponding binding vars in the original function of the outputs of 
the lifted function
+  std::vector<const VarNode*> outputs;
+  // The corresponding binding vars in the original function of the inputs of 
the lifted function
+  std::vector<const VarNode*> inputs;
+};
+
+/*! \brief Builder of the lifted function for cuda graph capturing or 
allocations */
+class FuncBuilder : public ExprMutator {
+ public:
+  /*!
+   * \brief Add a binding to the new function
+   * \param binding The binding to add
+   */
+  void AddBinding(const VarBindingNode* binding) { 
bindings_.push_back(binding); }
+
+  /*!
+   * \brief Mark a variable as the input of the new function.
+   * \param var The variable to mark as input
+   */
+  void MarkInput(const VarNode* var) { inputs_.push_back(var); }
+  /*!
+   * \brief Mark a variable as the output of the new function. The variable 
must be the LHS of an
+   * existing binding in the new function.
+   * \param var The variable to mark as output
+   */
+  void MarkOutput(const VarNode* var) { outputs_.push_back(var); }
+
+  /*! \brief Get the number of bindings in the new function */
+  auto size() const { return bindings_.size(); }
+
+  /*! \brief Build the new function */
+  Function Build() {
+    Array<Var> params;
+    // Set up the parameters
+    for (const auto* input : inputs_) {
+      auto new_var = Var(input->name_hint(), 
Downcast<Optional<StructInfo>>(input->struct_info_));
+      var_remap_[input->vid] = new_var;
+      params.push_back(new_var);
+    }
+    // Emit the function body
+    builder_->BeginBindingBlock();
+    for (const auto* binding : bindings_) {
+      VisitBinding_(binding);
+    }
+    // Set up the outputs
+    Array<Expr> outputs;
+    for (const auto* var : outputs_) {
+      outputs.push_back(VisitExpr_(var));
+    }
+    auto output = builder_->Emit(Tuple(outputs));
+    auto block = builder_->EndBlock();
+    auto body = builder_->Normalize(SeqExpr({block}, output));
+    auto func = Function(params, body, 
Downcast<StructInfo>(output->struct_info_.value()));
+    return func;
+  }
+
+  support::OrderedSet<const VarNode*> inputs_;
+  support::OrderedSet<const VarNode*> outputs_;
+  std::vector<const VarBindingNode*> bindings_;
+};
+
+/*!
+ * \brief The planner for rewriting the function to enable cuda graph 
capturing.
+ */
+class CUDAGraphRewritePlanner : public ExprVisitor {
+ public:
+  explicit CUDAGraphRewritePlanner(const IRModule& mod) : mod_(mod) {}
+  std::vector<LiftedFunctionRewritePlan> Plan() {
+    for (const auto& pair : mod_->functions) {
+      const auto& func = pair.second;
+      if (func->IsInstance<FunctionNode>()) {
+        VisitExpr(func);
+      }
+    }
+    std::vector<LiftedFunctionRewritePlan> plans;
+
+    auto region_to_plan = [&](FuncBuilder* region, bool is_alloc) -> 
LiftedFunctionRewritePlan {
+      LiftedFunctionRewritePlan plan;
+      plan.is_alloc = true;
+      plan.func = region->Build();
+      ICHECK(region->size());
+      plan.launch_point = region->bindings_.front()->var.get();
+      plan.is_alloc = is_alloc;
+      for (const auto* binding : region->bindings_) {
+        plan.lifted_bindings.insert(binding->var.get());
+      }
+      plan.inputs.assign(region->inputs_.begin(), region->inputs_.end());
+      plan.outputs.assign(region->outputs_.begin(), region->outputs_.end());
+      return plan;
+    };
+
+    for (auto* region : alloc_storages_) {
+      plans.push_back(region_to_plan(region, true));
+    }
+
+    for (auto* region : captured_regions_) {
+      plans.push_back(region_to_plan(region, false));
+    }
+    return plans;
+  }
+
+  /*!
+   * \brief Start a new static region. This method should be called when 
encountering a
+   * CUDA kernel launch (calls to PrimFunc or ExternFunc) that only depends on 
static parameters.
+   */
+  void StartRegion() { current_.capture_builder = arena_.make<FuncBuilder>(); }
+
+  /*!
+   * \brief Finish a static region. This method should be called when 
non-static bindings or
+   * unsupported operations are encountered.
+   */
+  void EndRegion() {
+    if (current_.capture_builder && current_.capture_builder->size()) {
+      captured_regions_.emplace_back(current_.capture_builder);
+    }
+    current_.capture_builder = nullptr;
+  }
+
+  void VisitBindingBlock_(const BindingBlockNode* binding_block) final {
+    Scope new_scope;
+    std::swap(new_scope, current_);
+    current_.alloc_storage_builder = arena_.make<FuncBuilder>();
+    for (const auto& binding : binding_block->bindings) {
+      VisitBinding(binding);
+    }
+    EndRegion();
+    if (current_.alloc_storage_builder->outputs_.size()) {
+      alloc_storages_.emplace_back(current_.alloc_storage_builder);
+    }
+    std::swap(new_scope, current_);
+  }
+
+  void VisitBinding_(const VarBindingNode* binding, const CallNode* call) 
final {
+    static const auto& mem_alloc_storage_op = 
Op::Get("relax.memory.alloc_storage");
+    static const auto& mem_kill_storage_op = 
Op::Get("relax.memory.kill_storage");
+    static const auto& builtin_alloc_tensor_op = 
Op::Get("relax.builtin.alloc_tensor");
+    static const auto& call_builtin_with_ctx_op = 
Op::Get("relax.call_builtin_with_ctx");
+
+    if (call->op.same_as(mem_alloc_storage_op) && 
IsStaticAllocStorage(binding)) {
+      AddStaticBinding(binding, /*is_alloc_storage=*/true);
+      return;
+    } else if (call->op.same_as(mem_kill_storage_op) || 
call->op.same_as(builtin_alloc_tensor_op)) {
+      return;
+    }
+
+    const auto* call_gv = call->op.as<GlobalVarNode>();
+    bool call_prim_func =
+        call_gv ? 
mod_->Lookup(GetRef<GlobalVar>(call_gv))->IsInstance<tir::PrimFuncNode>() : 
false;
+    bool is_kernel_launch = [&]() {
+      if (call_prim_func) {
+        return true;
+      }
+      if (call->op.as<ExternFuncNode>()) {
+        return true;
+      }
+      if (const auto* op = call->op.as<OpNode>()) {
+        return !support::StartsWith(op->name, "relax.memory") &&
+               !support::StartsWith(op->name, "relax.builtin") &&
+               !GetRef<Op>(op).same_as(call_builtin_with_ctx_op);
+      }
+      return false;
+    }();
+
+    std::vector<const VarNode*> args;
+    bool is_all_static = IsStatic(call->args, &args);
+    if (call_gv != nullptr && !call_prim_func) {
+      // calls to other Relax functions are not allowed
+      is_all_static = false;
+    }
+    if (is_all_static) {
+      if (current_.capture_builder == nullptr && is_kernel_launch) {
+        StartRegion();
+      }
+      AddStaticBinding(binding, /*is_alloc_storage=*/false);
+      MarkAsFuncInput(args);
+    } else {
+      EndRegion();
+    }
+
+    MarkAsFuncOutput(args);
+  }
+
+  void MarkAsFuncInput(const std::vector<const VarNode*>& vars) {
+    if (current_.capture_builder == nullptr) {
+      return;
+    }
+    for (const VarNode* var : vars) {
+      auto it = binding_to_region_.find(var);
+      if (it == binding_to_region_.end() || it->second != 
current_.capture_builder) {
+        current_.capture_builder->MarkInput(var);
+      }
+    }
+  }
+
+  void MarkAsFuncOutput(const std::vector<const VarNode*>& vars) {
+    for (const VarNode* var : vars) {
+      if (auto it = binding_to_region_.find(var);
+          it != binding_to_region_.end() && it->second != 
current_.capture_builder) {
+        it->second->MarkOutput(var);
+      }
+    }
+  }
+
+  void VisitBinding_(const VarBindingNode* binding, const VarNode* var) final {
+    if (IsStatic(GetRef<Var>(var))) {
+      AddStaticBinding(binding, false);
+      MarkAsFuncInput({var});
+    } else {
+      EndRegion();
+    }
+    MarkAsFuncOutput({var});
+  }
+
+  void VisitBinding_(const VarBindingNode* binding, const ConstantNode* 
constant) final {
+    AddStaticBinding(binding, false);
+  }
+
+  void VisitBinding_(const VarBindingNode* binding, const TupleNode* tuple) 
final {
+    std::vector<const VarNode*> args;
+    if (IsStatic(tuple->fields, &args)) {
+      AddStaticBinding(binding, false);
+      MarkAsFuncInput(args);
+    } else {
+      EndRegion();
+    }
+    MarkAsFuncOutput(args);
+  }
+
+  void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* 
tuple_get_item) final {
+    const VarNode* tuple = tuple_get_item->tuple.as<VarNode>();
+    ICHECK(tuple);
+    if (IsStatic(tuple_get_item->tuple)) {
+      AddStaticBinding(binding, false);
+      MarkAsFuncInput({tuple});
+    } else {
+      EndRegion();
+    }
+    MarkAsFuncOutput({tuple});
+  }
+
+  bool IsStatic(const PrimExpr& expr,
+                [[maybe_unused]] std::vector<const VarNode*>* vars_collector = 
nullptr) {
+    return expr->IsInstance<tir::IntImmNode>() || 
expr->IsInstance<tir::FloatImmNode>();
+  }
+
+  bool IsStatic(const Expr& expr, std::vector<const VarNode*>* vars_collector 
= nullptr) {
+    if (expr->IsInstance<ConstantNode>() || 
expr->IsInstance<DataTypeImmNode>()) {
+      return true;
+    }
+    if (const auto* prim_value = expr.as<PrimValueNode>()) {
+      return IsStatic(prim_value->value, vars_collector);
+    }
+    if (const auto* var = expr.as<VarNode>()) {
+      if (vars_collector != nullptr) {
+        vars_collector->push_back(var);
+      }
+      return static_bindings_.count(var);
+    }
+
+    if (const auto* shape = expr.as<ShapeExprNode>()) {
+      return IsStatic(shape->values, vars_collector);
+    }
+    if (const auto* tuple = expr.as<TupleNode>()) {
+      return IsStatic(tuple->fields, vars_collector);
+    }
+    return false;
+  }
+
+  template <typename T>
+  bool IsStatic(const Array<T>& exprs, std::vector<const VarNode*>* 
vars_collector = nullptr) {
+    return std::all_of(exprs.begin(), exprs.end(),
+                       [&](const T& expr) { return IsStatic(expr, 
vars_collector); });
+  }
+
+ private:
+  bool IsStaticAllocStorage(const VarBindingNode* binding) {
+    // Check if the allocation has constant shape
+    const auto* alloc_storage_call = binding->value.as<CallNode>();
+    auto shape = Downcast<ShapeExpr>(alloc_storage_call->args[0]);
+    return std::all_of(shape->values.begin(), shape->values.end(),
+                       [](const PrimExpr& expr) { return expr.as<IntImmNode>() 
!= nullptr; });
+  }
+
+  /*!
+   * \brief Add a static binding. This is used to mark the bindings that are 
known to be static
+   * and to further propagate that property to other bindings.
+   * \param binding the binding to add
+   * \param is_alloc_storage whether the binding is a call to 
R.memory.alloc_storage
+   */
+  void AddStaticBinding(const VarBindingNode* binding, bool is_alloc_storage) {
+    if (is_alloc_storage) {
+      current_.alloc_storage_builder->AddBinding(binding);
+      binding_to_region_[binding->var.get()] = current_.alloc_storage_builder;
+    } else if (current_.capture_builder != nullptr) {
+      // Add the binding if the capture builder exists. The capture builder 
may be
+      // null when the planner is not capturing. This is the case when no 
kernel launch has been
+      // encountered yet; in that case static bindings (e.g. bindings of other 
non-kernel-launch
+      // operations) are marked as static but are not lifted.
+      current_.capture_builder->AddBinding(binding);
+      binding_to_region_[binding->var.get()] = current_.capture_builder;
+    }
+    static_bindings_.emplace(binding->var.get(), GetRef<VarBinding>(binding));
+  }
+
+  /*! \brief The state of the current scope (the BindingBlock), which is a 
pair of FuncBuilders.
+   * The FuncBuilders are initialized with nullptr, meaning the planner is 
currently not doing any
+   * lifting. They are initialized lazily when a binding that can be lifted is 
encountered.
+   * They are reset to nullptr when an unsupported operation is encountered.
+   */
+  struct Scope {
+    FuncBuilder* alloc_storage_builder = nullptr;  // The builder for the 
allocation function
+    FuncBuilder* capture_builder = nullptr;        // The builder for the 
capture function
+  };
+
+  // The IRModule
+  IRModule mod_;
+  // States of the current scope
+  Scope current_;
+  // All the static bindings
+  std::unordered_map<const VarNode*, VarBinding> static_bindings_;
+  // Binding to the FuncBuilder if the binding is lifted. This is used to 
update the inputs/outputs
+  // of the lifted function when its binding is used outside.
+  std::unordered_map<const VarNode*, FuncBuilder*> binding_to_region_;
+  // The regions for capturing.
+  std::vector<FuncBuilder*> captured_regions_;
+  // The regions for allocation.
+  std::vector<FuncBuilder*> alloc_storages_;
+  // The arena.
+  support::Arena arena_;
+};
+
+/*! \brief The rewriter for CUDA graph */
+class CUDAGraphRewriter : public ExprMutator {
+ public:
+  explicit CUDAGraphRewriter(const IRModule& mod) : ExprMutator(mod) {}
+
+  IRModule Rewrite() {
+    CUDAGraphRewritePlanner planner(builder_->GetContextIRModule());
+    auto plans = planner.Plan();
+    for (const auto& plan : plans) {
+      subgraph_launches_[plan.launch_point] = plan;
+    }
+
+    for (const auto& [gv, func] : builder_->GetContextIRModule()->functions) {
+      if (func->IsInstance<FunctionNode>()) {
+        auto new_func = Downcast<Function>(VisitExpr(func));
+        if (!new_func.same_as(func)) {
+          builder_->UpdateFunction(gv, new_func);
+        }
+      }
+    }
+    return builder_->GetContextIRModule();
+  }
+
+  void LaunchSubgraph(const VarBindingNode* op, const 
LiftedFunctionRewritePlan& plan) {
+    static const auto& call_builtin_with_ctx_op = 
Op::Get("relax.call_builtin_with_ctx");
+    static const auto& builtin_run_or_capture = 
ExternFunc("vm.builtin.cuda_graph.run_or_capture");
+    static const auto& builtin_get_cached_alloc =
+        ExternFunc("vm.builtin.cuda_graph.get_cached_alloc");
+
+    Expr launch_subgraph;
+    auto gv_func =
+        builder_->AddFunction(plan.func, plan.is_alloc ? "cuda_graph_alloc" : 
"cuda_graph_capture");
+    if (plan.is_alloc) {
+      ICHECK(plan.inputs.empty());
+      launch_subgraph =
+          Call(call_builtin_with_ctx_op,
+               {builtin_get_cached_alloc,
+                Tuple({gv_func, PrimValue(IntImm(DataType::Int(64), 
index_alloc_++))})},
+               Attrs(), {plan.func->ret_struct_info});
+    } else {
+      Array<Expr> args;
+      for (const auto& arg : plan.inputs) {
+        args.push_back(VisitExpr_(arg));
+      }
+      launch_subgraph = Call(
+          call_builtin_with_ctx_op,
+          {builtin_run_or_capture,
+           Tuple({gv_func, Tuple(args), PrimValue(IntImm(DataType::Int(64), 
index_capture_++))})},
+          Attrs(), {plan.func->ret_struct_info});
+    }
+    Expr ret_value = builder_->Emit(launch_subgraph);
+    for (int i = 0; i < static_cast<int>(plan.outputs.size()); ++i) {
+      var_redef_[plan.outputs[i]] = TupleGetItem(ret_value, i);
+    }
+
+    lifted_bindings_.insert(plan.lifted_bindings.begin(), 
plan.lifted_bindings.end());
+  }
+
+  void VisitBinding_(const VarBindingNode* op) final {
+    if (subgraph_launches_.count(op->var.get())) {
+      LaunchSubgraph(op, subgraph_launches_[op->var.get()]);
+    }
+    if (auto it = var_redef_.find(op->var.get()); it != var_redef_.end()) {
+      auto new_var = builder_->Emit(it->second, op->var->name_hint());
+      var_remap_[op->var->vid] = new_var;
+      return;
+    }
+    if (lifted_bindings_.count(op->var.get())) {
+      // The binding is lifted to the subgraph and will be removed from the 
original function.
+      return;
+    }
+    ExprMutator::VisitBinding_(op);
+  }
+
+  std::unordered_map<const VarNode*, LiftedFunctionRewritePlan> 
subgraph_launches_;
+  std::unordered_map<const VarNode*, Expr> var_redef_;
+  std::unordered_set<const VarNode*> lifted_bindings_;
+  int index_alloc_ = 0;
+  int index_capture_ = 0;
+};
+
+IRModule RewriteCUDAGraph(IRModule mod) {
+  CUDAGraphRewriter rewriter(mod);
+  mod = rewriter.Rewrite();
+  return mod;
+}
+
+namespace transform {
+
+Pass RewriteCUDAGraph() {
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =  //
+      [=](IRModule m, PassContext pc) { return 
::tvm::relax::RewriteCUDAGraph(std::move(m)); };
+  return CreateModulePass(pass_func, 0, "RewriteCUDAGraph", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.RewriteCUDAGraph").set_body_typed(RewriteCUDAGraph);
+
+}  // namespace transform
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/support/ordered_set.h b/src/support/ordered_set.h
new file mode 100644
index 0000000000..8ba7708961
--- /dev/null
+++ b/src/support/ordered_set.h
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file support/ordered_set.h
+ * \brief An STL-like set implementation that preserves insertion order.
+ */
+#ifndef TVM_SUPPORT_ORDERED_SET_H_
+#define TVM_SUPPORT_ORDERED_SET_H_
+
+#include <list>
+#include <unordered_map>
+
+namespace tvm {
+namespace support {
+
+template <typename T>
+class OrderedSet {
+ public:
+  void push_back(const T& t) {
+    if (!elem_to_iter_.count(t)) {
+      elements_.push_back(t);
+      elem_to_iter_[t] = std::prev(elements_.end());
+    }
+  }
+
+  void erase(const T& t) {
+    if (auto it = elem_to_iter_.find(t); it != elem_to_iter_.end()) {
+      elements_.erase(it->second);
+      elem_to_iter_.erase(it);
+    }
+  }
+
+  void clear() {
+    elements_.clear();
+    elem_to_iter_.clear();
+  }
+
+  auto begin() const { return elements_.begin(); }
+  auto end() const { return elements_.end(); }
+  auto size() const { return elements_.size(); }
+  auto empty() const { return elements_.empty(); }
+
+ private:
+  std::list<T> elements_;
+  std::unordered_map<T, typename std::list<T>::iterator> elem_to_iter_;
+};
+
+}  // namespace support
+}  // namespace tvm
+
+#endif  // TVM_SUPPORT_ORDERED_SET_H_
diff --git a/tests/python/relax/test_transform_rewrite_cuda_graph.py 
b/tests/python/relax/test_transform_rewrite_cuda_graph.py
new file mode 100644
index 0000000000..4fc4d6f4a1
--- /dev/null
+++ b/tests/python/relax/test_transform_rewrite_cuda_graph.py
@@ -0,0 +1,228 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+from tvm import relax
+from tvm.script import tir as T, relax as R, ir as I
+import tvm.testing
+
+
+def test_rewrite_cuda_graph():
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def exp(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), 
compute: T.Buffer((T.int64(2), T.int64(4)), "float32")):
+            # function attr dict
+            T.func_attr({"tir.noalias": True, "global_symbol": "exp"})
+            # body
+            # with T.block("root")
+            for i0_i1_fused_0 in T.thread_binding(T.int64(1), 
thread="blockIdx.x"):
+                for i0_i1_fused_1 in T.thread_binding(T.int64(8), 
thread="threadIdx.x"):
+                    with T.block("compute"):
+                        i0 = T.axis.spatial(T.int64(2), (i0_i1_fused_0 * 
T.int64(8) + i0_i1_fused_1) // T.int64(4))
+                        i1 = T.axis.spatial(T.int64(4), (i0_i1_fused_0 * 
T.int64(8) + i0_i1_fused_1) % T.int64(4))
+                        T.reads(rxplaceholder[i0, i1])
+                        T.writes(compute[i0, i1])
+                        compute[i0, i1] = T.exp(rxplaceholder[i0, i1], 
dtype="float32")
+
+
+        @R.function
+        def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((10,), 
dtype="float32"):
+            cls = Before
+            storage: R.Object = R.memory.alloc_storage(R.shape([32]), 0, 
"global", "float32")
+            alloc: R.Tensor((2, 4), dtype="float32") = 
R.memory.alloc_tensor(storage, 0, R.shape([2, 4]), "float32")
+            _1: R.Tuple = cls.exp(x, alloc)
+            storage1: R.Object = R.memory.alloc_storage(R.shape([32]), 0, 
"global", "float32")
+            alloc1: R.Tensor((2, 4), dtype="float32") = 
R.memory.alloc_tensor(storage1, 0, R.shape([2, 4]), "float32")
+            _2: R.Tuple = cls.exp(alloc, alloc1)
+            _3: R.Tuple = R.memory.kill_tensor(alloc)
+            alloc2: R.Tensor((2, 4), dtype="float32") = 
R.memory.alloc_tensor(storage, 0, R.shape([2, 4]), "float32")
+            _4: R.Tuple = cls.exp(alloc1, alloc2)
+            _5: R.Tuple = R.memory.kill_tensor(alloc1)
+            alloc3: R.Tensor((2, 4), dtype="float32") = 
R.builtin.alloc_tensor(R.shape([2, 4]), "float32", 0)
+            _6 = cls.exp(alloc2, alloc3)
+            _7: R.Tuple = R.memory.kill_tensor(alloc2)
+            _8: R.Tuple = R.memory.kill_storage(storage)
+            _9: R.Tuple = R.memory.kill_storage(storage1)
+            return alloc3
+
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def exp(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), 
compute: T.Buffer((T.int64(2), T.int64(4)), "float32")):
+            # function attr dict
+            T.func_attr({"tir.noalias": True, "global_symbol": "exp"})
+            # body
+            # with T.block("root")
+            for i0_i1_fused_0 in T.thread_binding(T.int64(1), 
thread="blockIdx.x"):
+                for i0_i1_fused_1 in T.thread_binding(T.int64(8), 
thread="threadIdx.x"):
+                    with T.block("compute"):
+                        i0 = T.axis.spatial(T.int64(2), (i0_i1_fused_0 * 
T.int64(8) + i0_i1_fused_1) // T.int64(4))
+                        i1 = T.axis.spatial(T.int64(4), (i0_i1_fused_0 * 
T.int64(8) + i0_i1_fused_1) % T.int64(4))
+                        T.reads(rxplaceholder[i0, i1])
+                        T.writes(compute[i0, i1])
+                        compute[i0, i1] = T.exp(rxplaceholder[i0, i1], 
dtype="float32")
+
+        @R.function
+        def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object):
+            gv: R.Object = R.memory.alloc_storage(R.shape([32]), 
R.prim_value(0), R.str("global"), R.dtype("float32"))
+            gv1: R.Object = R.memory.alloc_storage(R.shape([32]), 
R.prim_value(0), R.str("global"), R.dtype("float32"))
+            gv2: R.Tuple(R.Object, R.Object) = (gv, gv1)
+            return gv2
+
+        @R.function
+        def cuda_graph_capture(alloc: R.Tensor((2, 4), dtype="float32"), 
alloc1: R.Tensor((2, 4), dtype="float32"), storage: R.Object) -> 
R.Tuple(R.Tensor((2, 4), dtype="float32")):
+            cls = Expected
+            _2: R.Tuple = cls.exp(alloc, alloc1)
+            _3: R.Tuple = R.memory.kill_tensor(alloc)
+            alloc2: R.Tensor((2, 4), dtype="float32") = 
R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([2, 4]), 
R.dtype("float32"))
+            _4: R.Tuple = cls.exp(alloc1, alloc2)
+            _5: R.Tuple = R.memory.kill_tensor(alloc1)
+            gv: R.Tuple(R.Tensor((2, 4), dtype="float32")) = (alloc2,)
+            return gv
+
+        @R.function
+        def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((10,), 
dtype="float32"):
+            cls = Expected
+            gv: R.Tuple(R.Object, R.Object) = 
R.call_builtin_with_ctx("vm.builtin.cuda_graph.get_cached_alloc", 
(cls.cuda_graph_alloc, R.prim_value(0)), sinfo_args=(R.Tuple(R.Object, 
R.Object),))
+            storage: R.Object = gv[0]
+            alloc: R.Tensor((2, 4), dtype="float32") = 
R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([2, 4]), 
R.dtype("float32"))
+            _1: R.Tuple = cls.exp(x, alloc)
+            storage1: R.Object = gv[1]
+            alloc1: R.Tensor((2, 4), dtype="float32") = 
R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([2, 4]), 
R.dtype("float32"))
+            gv1: R.Tuple(R.Tensor((2, 4), dtype="float32")) = 
R.call_builtin_with_ctx("vm.builtin.cuda_graph.run_or_capture", 
(cls.cuda_graph_capture, (alloc, alloc1, storage), R.prim_value(0)), 
sinfo_args=(R.Tuple(R.Tensor((2, 4), dtype="float32")),))
+            alloc2: R.Tensor((2, 4), dtype="float32") = gv1[0]
+            alloc3: R.Tensor((2, 4), dtype="float32") = 
R.builtin.alloc_tensor(R.shape([2, 4]), R.dtype("float32"), R.prim_value(0))
+            _6: R.Tuple = cls.exp(alloc2, alloc3)
+            _7: R.Tuple = R.memory.kill_tensor(alloc2)
+            _8: R.Tuple = R.memory.kill_storage(storage)
+            _9: R.Tuple = R.memory.kill_storage(storage1)
+            return alloc3
+    # fmt: on
+
+    after = relax.transform.RewriteCUDAGraph()(Before)
+    tvm.ir.assert_structural_equal(after, Expected)
+
+
+def test_tuple():
+    """Check RewriteCUDAGraph when tuple packing and indexing sit between kernel calls.
+
+    ``Before.main`` wraps ``alloc1`` in two nested tuples (``lv0``, ``lv1``) and unwraps it
+    again (``lv2``, ``lv3``) before passing it to ``exp``. ``Expected`` keeps those tuple
+    bindings inside ``cuda_graph_capture`` together with the surrounding ``exp`` calls, and
+    places the two ``alloc_storage`` calls in ``cuda_graph_alloc``.
+    """
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def exp(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), 
compute: T.Buffer((T.int64(2), T.int64(4)), "float32")):
+            # function attr dict
+            T.func_attr({"tir.noalias": True, "global_symbol": "exp"})
+            # body
+            # with T.block("root")
+            for i0_i1_fused_0 in T.thread_binding(T.int64(1), 
thread="blockIdx.x"):
+                for i0_i1_fused_1 in T.thread_binding(T.int64(8), 
thread="threadIdx.x"):
+                    with T.block("compute"):
+                        i0 = T.axis.spatial(T.int64(2), (i0_i1_fused_0 * 
T.int64(8) + i0_i1_fused_1) // T.int64(4))
+                        i1 = T.axis.spatial(T.int64(4), (i0_i1_fused_0 * 
T.int64(8) + i0_i1_fused_1) % T.int64(4))
+                        T.reads(rxplaceholder[i0, i1])
+                        T.writes(compute[i0, i1])
+                        compute[i0, i1] = T.exp(rxplaceholder[i0, i1], 
dtype="float32")
+
+
+        @R.function
+        def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((2, 4), 
dtype="float32"):
+            cls = Before
+            storage: R.Object = R.memory.alloc_storage(R.shape([32]), 0, 
"global", "float32")
+            alloc: R.Tensor((2, 4), dtype="float32") = 
R.memory.alloc_tensor(storage, 0, R.shape([2, 4]), "float32")
+            _: R.Tuple = cls.exp(x, alloc)
+            storage1: R.Object = R.memory.alloc_storage(R.shape([32]), 0, 
"global", "float32")
+            alloc1: R.Tensor((2, 4), dtype="float32") = 
R.memory.alloc_tensor(storage1, 0, R.shape([2, 4]), "float32")
+            _: R.Tuple = cls.exp(alloc, alloc1)
+            # Pack alloc1 into nested tuples and unpack it again: lv3 is alloc1 reached
+            # through two tuple constructions and two tuple indexings.
+            lv0 = (alloc1,)
+            lv1 = (lv0,)
+            lv2 = lv1[0]
+            lv3 = lv2[0]
+            alloc2: R.Tensor((2, 4), dtype="float32") = 
R.memory.alloc_tensor(storage, 0, R.shape([2, 4]), "float32")
+            _1: R.Tuple = cls.exp(lv3, alloc2)
+            _2: R.Tuple = R.memory.kill_tensor(alloc)
+            _3: R.Tuple = R.memory.kill_tensor(alloc1)
+            alloc3: R.Tensor((2, 4), dtype="float32") = 
R.builtin.alloc_tensor(R.shape([2, 4]), R.dtype("float32"), R.prim_value(0))
+            _4: R.Tuple = cls.exp(alloc2, alloc3)
+            _5: R.Tuple = R.memory.kill_tensor(alloc2)
+            _6: R.Tuple = R.memory.kill_storage(storage)
+            _7: R.Tuple = R.memory.kill_storage(storage1)
+            # NOTE(review): alloc2 is returned after kill_tensor(alloc2) above, while
+            # test_rewrite_cuda_graph returns alloc3 -- confirm this is intended.
+            return alloc2
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def exp(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), 
compute: T.Buffer((T.int64(2), T.int64(4)), "float32")):
+            T.func_attr({"global_symbol": "exp", "tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            for i0_i1_fused_0 in T.thread_binding(T.int64(1), 
thread="blockIdx.x"):
+                for i0_i1_fused_1 in T.thread_binding(T.int64(8), 
thread="threadIdx.x"):
+                    with T.block("compute"):
+                        i0 = T.axis.spatial(T.int64(2), (i0_i1_fused_0 * 
T.int64(8) + i0_i1_fused_1) // T.int64(4))
+                        i1 = T.axis.spatial(T.int64(4), (i0_i1_fused_0 * 
T.int64(8) + i0_i1_fused_1) % T.int64(4))
+                        T.reads(rxplaceholder[i0, i1])
+                        T.writes(compute[i0, i1])
+                        compute[i0, i1] = T.exp(rxplaceholder[i0, i1])
+
+        # Holds the two alloc_storage calls of Before.main; returns both storages as a tuple.
+        @R.function
+        def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object):
+            storage: R.Object = R.memory.alloc_storage(R.shape([32]), 
R.prim_value(0), R.str("global"), R.dtype("float32"))
+            storage1: R.Object = R.memory.alloc_storage(R.shape([32]), 
R.prim_value(0), R.str("global"), R.dtype("float32"))
+            gv: R.Tuple(R.Object, R.Object) = (storage, storage1)
+            return gv
+
+        # Holds the second and third exp calls of Before.main together with the tuple
+        # bindings lv0 .. lv3 between them; alloc2 is the only value handed back.
+        @R.function
+        def cuda_graph_capture(alloc: R.Tensor((2, 4), dtype="float32"), 
alloc1: R.Tensor((2, 4), dtype="float32"), storage: R.Object) -> 
R.Tuple(R.Tensor((2, 4), dtype="float32")):
+            cls = Expected
+            _: R.Tuple = cls.exp(alloc, alloc1)
+            lv0: R.Tuple(R.Tensor((2, 4), dtype="float32")) = (alloc1,)
+            lv1: R.Tuple(R.Tuple(R.Tensor((2, 4), dtype="float32"))) = (lv0,)
+            lv2: R.Tuple(R.Tensor((2, 4), dtype="float32")) = lv1[0]
+            lv3: R.Tensor((2, 4), dtype="float32") = lv2[0]
+            alloc2: R.Tensor((2, 4), dtype="float32") = 
R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([2, 4]), 
R.dtype("float32"))
+            _1: R.Tuple = cls.exp(lv3, alloc2)
+            _2: R.Tuple = R.memory.kill_tensor(alloc)
+            _3: R.Tuple = R.memory.kill_tensor(alloc1)
+            gv: R.Tuple(R.Tensor((2, 4), dtype="float32")) = (alloc2,)
+            return gv
+
+        @R.function
+        def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((2, 4), 
dtype="float32"):
+            cls = Expected
+            gv: R.Tuple(R.Object, R.Object) = 
R.call_builtin_with_ctx("vm.builtin.cuda_graph.get_cached_alloc", 
(cls.cuda_graph_alloc, R.prim_value(0)), sinfo_args=(R.Tuple(R.Object, 
R.Object),))
+            storage: R.Object = gv[0]
+            alloc: R.Tensor((2, 4), dtype="float32") = 
R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([2, 4]), 
R.dtype("float32"))
+            _: R.Tuple = cls.exp(x, alloc)
+            storage1: R.Object = gv[1]
+            alloc1: R.Tensor((2, 4), dtype="float32") = 
R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([2, 4]), 
R.dtype("float32"))
+            gv1: R.Tuple(R.Tensor((2, 4), dtype="float32")) = 
R.call_builtin_with_ctx("vm.builtin.cuda_graph.run_or_capture", 
(cls.cuda_graph_capture, (alloc, alloc1, storage), R.prim_value(0)), 
sinfo_args=(R.Tuple(R.Tensor((2, 4), dtype="float32")),))
+            alloc2: R.Tensor((2, 4), dtype="float32") = gv1[0]
+            alloc3: R.Tensor((2, 4), dtype="float32") = 
R.builtin.alloc_tensor(R.shape([2, 4]), R.dtype("float32"), R.prim_value(0))
+            _4: R.Tuple = cls.exp(alloc2, alloc3)
+            _5: R.Tuple = R.memory.kill_tensor(alloc2)
+            _6: R.Tuple = R.memory.kill_storage(storage)
+            _7: R.Tuple = R.memory.kill_storage(storage1)
+            return alloc2
+    # fmt: on
+
+    # Run the pass on Before and require the result to be structurally identical to Expected.
+    after = relax.transform.RewriteCUDAGraph()(Before)
+    tvm.ir.assert_structural_equal(after, Expected)
+
+
+# Script entry point: hand control to tvm.testing.main() when this file is executed directly.
+if __name__ == "__main__":
+    tvm.testing.main()


Reply via email to