This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new f83a32906f [Relax] Share storage allocs among functions after cuda graph rewriting (#16830)
f83a32906f is described below
commit f83a32906f9d3765946db0b9bdc31e4eef5072b3
Author: Wuwei Lin <[email protected]>
AuthorDate: Mon Apr 1 21:09:55 2024 -0700
[Relax] Share storage allocs among functions after cuda graph rewriting (#16830)
---
src/relax/transform/rewrite_cuda_graph.cc | 386 ++++++++++++++++-----
.../relax/test_transform_rewrite_cuda_graph.py | 241 ++++++++++++-
2 files changed, 518 insertions(+), 109 deletions(-)
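The key new piece is MergeAllocationPlans in rewrite_cuda_graph.cc: the alloc_storage bindings lifted out of every rewritten function are grouped by storage scope, each function's records are sorted by size in descending order, and the i-th largest allocation of every function is folded into one shared storage of the maximum size among them (relying on the assumption, noted in the code, that the functions do not run concurrently). Below is a minimal standalone Python sketch of that strategy; it is not part of the commit, merge_alloc_plans is a hypothetical helper, and sizes are plain integers rather than relax.memory.alloc_storage bindings:

    from collections import defaultdict
    from itertools import zip_longest

    def merge_alloc_plans(plans):
        # plans: one list of (storage_scope, size_in_bytes) records per function.
        # Returns the merged allocations plus, per function, the index of the
        # merged allocation each original record is remapped to.
        by_scope = defaultdict(lambda: [[] for _ in plans])
        for fn_idx, records in enumerate(plans):
            for rec_idx, (scope, size) in enumerate(records):
                by_scope[scope][fn_idx].append((size, fn_idx, rec_idx))
        merged = []
        remap = [[None] * len(records) for records in plans]
        for scope in sorted(by_scope):  # ordered scopes keep the result deterministic
            per_fn = by_scope[scope]
            for records in per_fn:
                # Largest first; a stable sort preserves order among equal sizes.
                records.sort(key=lambda r: r[0], reverse=True)
            # Walk all functions in lockstep: the i-th largest allocations of
            # every function share one storage, sized to the largest of them.
            for group in zip_longest(*per_fn):
                group = [g for g in group if g is not None]
                size = max(s for s, _, _ in group)
                for _, fn_idx, rec_idx in group:
                    remap[fn_idx][rec_idx] = len(merged)
                merged.append((scope, size))
        return merged, remap

Applied to the two functions in the new TestMergeAllocFuncs test below (128/256 "global" and 512 "ipc_memory" bytes in func1; 192/64/512 "global" and 1024 "ipc_memory" bytes in func2), this yields shared allocations of 512, 192, and 64 bytes in "global" plus 1024 bytes in "ipc_memory", matching the expected cuda_graph_alloc function.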
diff --git a/src/relax/transform/rewrite_cuda_graph.cc b/src/relax/transform/rewrite_cuda_graph.cc
index 25b229ebce..d0e20ffd76 100644
--- a/src/relax/transform/rewrite_cuda_graph.cc
+++ b/src/relax/transform/rewrite_cuda_graph.cc
@@ -49,17 +49,19 @@
* 2. Lift the regions identified in step 1 to a separate function and rewrite the original function
* with `CUDAGraphRewriter`.
*/
-
+#include <tvm/relax/analysis.h>
#include <tvm/relax/backend.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/expr_functor.h>
#include <tvm/tir/stmt_functor.h>
+#include <unordered_map>
+#include <vector>
+
#include "../../support/arena.h"
#include "../../support/ordered_set.h"
#include "../../support/utils.h"
-
namespace tvm {
namespace relax {
@@ -79,9 +81,10 @@ struct LiftedFunctionRewritePlan {
// Variable remappings between the original function and the lifted function
// The bindings in the original function that are lifted
- std::unordered_set<const VarNode*> lifted_bindings;
+ std::vector<const VarBindingNode*> lifted_bindings;
// The corresponding binding vars in the original function of the outputs of the lifted function
- std::vector<const VarNode*> outputs;
+ // to the index of the element in the output tuple of the lifted function.
+ std::unordered_map<const VarNode*, int> outputs;
// The corresponding binding vars in the original function of the inputs of the lifted function
std::vector<const VarNode*> inputs;
// The tir vars in the original function that are propagated to the lifted function
@@ -170,13 +173,68 @@ class FuncBuilder : public ExprMutator {
Map<tir::Var, PrimExpr> tir_var_remap_;
};
+// Collect the storage objects that are used as the function output
+class OutputStorageCollector : public ExprVisitor {
+ public:
+ static std::unordered_set<const VarNode*> Collect(const Function& func) {
+ OutputStorageCollector collector;
+ collector.VisitExpr(func);
+ return std::move(collector.output_storages_);
+ }
+
+ private:
+ void VisitExpr_(const SeqExprNode* seq_expr) final {
+ auto output_vars = FreeVars(seq_expr->body);
+ for (const auto& var : output_vars) {
+ output_vars_.insert(var.get());
+ }
+ // Visit the blocks in reverse order for backward propagation
+ for (auto it = seq_expr->blocks.rbegin(); it != seq_expr->blocks.rend(); ++it) {
+ VisitBindingBlock(*it);
+ }
+ }
+
+ void VisitBinding_(const VarBindingNode* binding, const CallNode* call) final {
+ static const auto& mem_alloc_tensor_op = Op::Get("relax.memory.alloc_tensor");
+ if (output_vars_.count(binding->var.get()) && call->op.same_as(mem_alloc_tensor_op)) {
+ output_storages_.insert(call->args[0].as<VarNode>());
+ }
+ }
+
+ void VisitBindingBlock_(const BindingBlockNode* binding_block) override {
+ // Visit the bindings in reverse order
+ for (auto it = binding_block->bindings.rbegin(); it != binding_block->bindings.rend(); ++it) {
+ VisitBinding(*it);
+ }
+ }
+
+ void VisitBinding_(const VarBindingNode* binding, const VarNode* var) final {
+ if (output_vars_.count(binding->var.get())) {
+ output_vars_.insert(var);
+ }
+ }
+
+ void VisitBinding_(const VarBindingNode* binding, const TupleNode* tuple) final {
+ if (output_vars_.count(binding->var.get())) {
+ for (const auto& field : tuple->fields) {
+ output_vars_.insert(field.as<VarNode>());
+ }
+ }
+ }
+
+ std::unordered_set<const VarNode*> output_storages_;
+ std::unordered_set<const VarNode*> output_vars_;
+};
+
/*!
* \brief The planner for rewriting the function to enable cuda graph capturing.
*/
class CUDAGraphRewritePlanner : public ExprVisitor {
public:
- explicit CUDAGraphRewritePlanner(const IRModule& mod) : mod_(mod) {}
- std::vector<LiftedFunctionRewritePlan> Plan() {
+ explicit CUDAGraphRewritePlanner(const IRModule& mod, support::Arena* arena)
+ : mod_(mod), arena_(arena) {}
+ std::pair<std::vector<LiftedFunctionRewritePlan*>, std::vector<LiftedFunctionRewritePlan*>>
+ Plan() {
for (const auto& pair : mod_->functions) {
if (pair.second->IsInstance<FunctionNode>()) {
// If a function has the num_input attribute, the last func->params.size() - num_inputs
@@ -188,41 +246,41 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
}
}
CollectSymbolicVarHints(func);
+ disabled_storage_vars_ = OutputStorageCollector::Collect(func);
VisitExpr(func);
}
}
- std::vector<LiftedFunctionRewritePlan> plans;
-
- auto region_to_plan = [&](FuncBuilder* region, bool is_alloc) -> LiftedFunctionRewritePlan {
- LiftedFunctionRewritePlan plan;
- plan.is_alloc = true;
- plan.func = region->Build();
+ auto region_to_plan = [&](FuncBuilder* region, bool is_alloc) -> LiftedFunctionRewritePlan* {
+ auto* plan = arena_->make<LiftedFunctionRewritePlan>();
+ plan->is_alloc = true;
+ plan->func = region->Build();
ICHECK(region->size());
- plan.launch_point = region->bindings_.front()->var.get();
- plan.is_alloc = is_alloc;
- for (const auto* binding : region->bindings_) {
- plan.lifted_bindings.insert(binding->var.get());
- }
+ plan->launch_point = region->bindings_.front()->var.get();
+ plan->is_alloc = is_alloc;
+ plan->lifted_bindings = std::move(region->bindings_);
if (region->shape_expr_inputs_.size()) {
Array<PrimExpr> tir_vars;
for (const auto* var : region->shape_expr_inputs_) {
tir_vars.push_back(GetRef<PrimExpr>(var));
}
- plan.propogated_tir_vars = ShapeExpr(tir_vars);
+ plan->propogated_tir_vars = ShapeExpr(tir_vars);
+ }
+ plan->inputs.assign(region->inputs_.begin(), region->inputs_.end());
+ for (const auto* var : region->outputs_) {
+ plan->outputs[var] = plan->outputs.size();
}
- plan.inputs.assign(region->inputs_.begin(), region->inputs_.end());
- plan.outputs.assign(region->outputs_.begin(), region->outputs_.end());
return plan;
};
- for (auto* region : alloc_storages_) {
- plans.push_back(region_to_plan(region, /*is_alloc=*/true));
- }
-
- for (auto* region : captured_regions_) {
- plans.push_back(region_to_plan(region, /*is_alloc=*/false));
- }
- return plans;
+ std::vector<LiftedFunctionRewritePlan*> alloc_plans, capture_plans;
+ alloc_plans.reserve(alloc_storages_.size());
+ capture_plans.reserve(captured_regions_.size());
+ std::transform(alloc_storages_.begin(), alloc_storages_.end(), std::back_inserter(alloc_plans),
+ [&](FuncBuilder* region) { return region_to_plan(region, /*is_alloc=*/true); });
+ std::transform(captured_regions_.begin(), captured_regions_.end(),
+ std::back_inserter(capture_plans),
+ [&](FuncBuilder* region) { return region_to_plan(region, /*is_alloc=*/false); });
+ return {std::move(alloc_plans), std::move(capture_plans)};
}
/*!
@@ -241,31 +299,36 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
*\brief Start a new static region. This method should be called when encountering a
* CUDA kernel launch (calls to PrimFunc or ExternFunc) that only depends on static parameters.
*/
- void StartRegion() { current_.capture_builder = arena_.make<FuncBuilder>(); }
+ void StartRegion() { current_block_scope_.capture_builder = arena_->make<FuncBuilder>(); }
/*!
* \brief Finish a static region. This method should be called when non-static bindings or
* unsupported operations are encountered.
*/
void EndRegion() {
- if (current_.capture_builder && current_.capture_builder->size()) {
- captured_regions_.emplace_back(current_.capture_builder);
+ if (current_block_scope_.capture_builder && current_block_scope_.capture_builder->size()) {
+ captured_regions_.emplace_back(current_block_scope_.capture_builder);
}
- current_.capture_builder = nullptr;
+ current_block_scope_.capture_builder = nullptr;
+ }
+
+ void VisitExpr_(const FunctionNode* func) final {
+ current_function_scope_.alloc_storage_builder = arena_->make<FuncBuilder>();
+ ExprVisitor::VisitExpr_(func);
+ if (current_function_scope_.alloc_storage_builder->outputs_.size()) {
+ alloc_storages_.emplace_back(current_function_scope_.alloc_storage_builder);
+ }
+ current_function_scope_.alloc_storage_builder = nullptr;
}
void VisitBindingBlock_(const BindingBlockNode* binding_block) final {
- Scope new_scope;
- std::swap(new_scope, current_);
- current_.alloc_storage_builder = arena_.make<FuncBuilder>();
+ BindingBlockScope new_scope;
+ std::swap(new_scope, current_block_scope_);
for (const auto& binding : binding_block->bindings) {
VisitBinding(binding);
}
EndRegion();
- if (current_.alloc_storage_builder->outputs_.size()) {
- alloc_storages_.emplace_back(current_.alloc_storage_builder);
- }
- std::swap(new_scope, current_);
+ std::swap(new_scope, current_block_scope_);
}
void VisitBinding_(const VarBindingNode* binding, const CallNode* call) final {
@@ -273,8 +336,10 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
static const auto& builtin_alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor");
static const auto& call_builtin_with_ctx_op = Op::Get("relax.call_builtin_with_ctx");
- if (call->op.same_as(mem_alloc_storage_op) && IsStaticAllocStorage(binding)) {
- AddStaticBinding(binding, /*is_alloc_storage=*/true);
+ if (call->op.same_as(mem_alloc_storage_op)) {
+ if (IsStaticAllocStorage(binding)) {
+ AddStaticBinding(binding, /*is_alloc_storage=*/true);
+ }
return;
} else if (call->op.same_as(builtin_alloc_tensor_op)) {
return;
@@ -321,7 +386,7 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
}
return false;
}();
- if (current_.capture_builder == nullptr && is_kernel_launch) {
+ if (current_block_scope_.capture_builder == nullptr && is_kernel_launch) {
StartRegion();
}
AddStaticBinding(binding, /*is_alloc_storage=*/false);
@@ -335,24 +400,24 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
void MarkAsFuncInput(const std::vector<const VarNode*>& vars,
const std::vector<const tir::VarNode*>& tir_vars = {}) {
- if (current_.capture_builder == nullptr) {
+ if (current_block_scope_.capture_builder == nullptr) {
return;
}
for (const VarNode* var : vars) {
auto it = binding_to_region_.find(var);
- if (it == binding_to_region_.end() || it->second != current_.capture_builder) {
- current_.capture_builder->MarkInput(var);
+ if (it == binding_to_region_.end() || it->second != current_block_scope_.capture_builder) {
+ current_block_scope_.capture_builder->MarkInput(var);
}
}
for (const tir::VarNode* tir_var : tir_vars) {
- current_.capture_builder->MarkShapeExprInput(tir_var);
+ current_block_scope_.capture_builder->MarkShapeExprInput(tir_var);
}
}
void MarkAsFuncOutput(const std::vector<const VarNode*>& vars) {
for (const VarNode* var : vars) {
if (auto it = binding_to_region_.find(var);
- it != binding_to_region_.end() && it->second != current_.capture_builder) {
+ it != binding_to_region_.end() && it->second != current_block_scope_.capture_builder) {
it->second->MarkOutput(var);
}
}
@@ -476,6 +541,9 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
private:
bool IsStaticAllocStorage(const VarBindingNode* binding) {
+ if (disabled_storage_vars_.count(binding->var.get())) {
+ return false;
+ }
// Check if the allocation has constant shape
const auto* alloc_storage_call = binding->value.as<CallNode>();
auto shape = Downcast<ShapeExpr>(alloc_storage_call->args[0]);
@@ -491,33 +559,41 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
*/
void AddStaticBinding(const VarBindingNode* binding, bool is_alloc_storage) {
if (is_alloc_storage) {
- current_.alloc_storage_builder->AddBinding(binding);
- binding_to_region_[binding->var.get()] = current_.alloc_storage_builder;
- } else if (current_.capture_builder != nullptr) {
+ current_function_scope_.alloc_storage_builder->AddBinding(binding);
+ binding_to_region_[binding->var.get()] = current_function_scope_.alloc_storage_builder;
+ } else if (current_block_scope_.capture_builder != nullptr) {
// Add the binding if the capture builder exists. It is possible that capture builder is
// null when it is not capturing. This is the case that there are not yet any kernel launches
// encountered, in this case static bindings (e.g. binding of other non-kernel-launch
// operations) are marked but are not lifted.
- current_.capture_builder->AddBinding(binding);
- binding_to_region_[binding->var.get()] = current_.capture_builder;
+ current_block_scope_.capture_builder->AddBinding(binding);
+ binding_to_region_[binding->var.get()] = current_block_scope_.capture_builder;
}
static_vars_.emplace(binding->var.get());
}
- /*! \brief The states of the current scope (the BindingBlock) which is a pair of FuncBuilder.
+ /*! \brief The states of the current scope (the BindingBlock) which is a FuncBuilder.
* The FuncBuilder are initialized with nullptr, meaning the planner is currently not doing any
* lifting. They are initialized lazily when a binding that can be lifted is encountered.
* They are reset to nullptr when an unsupported operation is encountered.
*/
- struct Scope {
+ struct BindingBlockScope {
+ FuncBuilder* capture_builder = nullptr; // The builder for the capture function
+ };
+
+ /*! \brief The states of the current function scope which is a FuncBuilder to build the storage
+ * allocation function.
+ */
+ struct FunctionScope {
FuncBuilder* alloc_storage_builder = nullptr; // The builder for the allocation function
- FuncBuilder* capture_builder = nullptr; // The builder for the capture function
};
// The IRModule
IRModule mod_;
- // States of the current scope
- Scope current_;
+ // States of the current block scope
+ BindingBlockScope current_block_scope_;
+ // States of the current function scope
+ FunctionScope current_function_scope_;
// Variables whose buffer address is fixed
std::unordered_set<const VarNode*> static_vars_;
// The name of the variables that are allowed to be symbolic
@@ -529,64 +605,183 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
std::vector<FuncBuilder*> captured_regions_;
// The regions for allocation.
std::vector<FuncBuilder*> alloc_storages_;
+ // The binding variables that are not allowed to be captured.
+ std::unordered_set<const VarNode*> disabled_storage_vars_;
// The arena.
- support::Arena arena_;
+ support::Arena* arena_;
};
+/*!
+ * \brief Merge storage allocations from different functions by reusing the largest allocation that
+ * can be shared among all the functions. The original rewriting plans are updated in-place to use
+ * the merged storage allocations.
+ *
+ * When multiple functions are rewritten to be executed with CUDA graph, the storage allocations
+ * from different functions can be reused. This function merges multiple storage allocation
+ * functions into a single function that allocates storage sufficiently large to be shared among
+ * all the functions.
+ *
+ * \param alloc_plans The allocation plans of the functions to be merged.
+ * \return The new allocation function that merges the storage allocations.
+ */
+Function MergeAllocationPlans(const std::vector<LiftedFunctionRewritePlan*>& alloc_plans) {
+ // The storage record that contains the size of the storage allocation and the binding of the
+ // storage allocation.
+ struct StorageRecord {
+ // The size of the storage object in bytes
+ int64_t size;
+ // The binding of the storage allocation
+ const VarBindingNode* binding;
+ // The source rewriting plan that the storage record is from
+ LiftedFunctionRewritePlan* src;
+
+ bool operator<(const StorageRecord& other) const { return size < other.size; }
+ };
+ // Using an (ordered) map to make sure the result is deterministic
+ std::map<String, std::vector<std::vector<StorageRecord>>> storage_records;
+ static const auto& mem_alloc_storage_op = Op::Get("relax.memory.alloc_storage");
+
+ // Collect the storage records for each storage scope. Storage records are stored separately
+ // for each original function.
+ for (int plan_id = 0; plan_id < static_cast<int>(alloc_plans.size()); ++plan_id) {
+ LiftedFunctionRewritePlan* plan = alloc_plans[plan_id];
+ ICHECK(plan->is_alloc);
+ for (const VarBindingNode* binding : plan->lifted_bindings) {
+ // Extract the storage record from the Call expr.
+ Call alloc_storage = Downcast<Call>(binding->value);
+ ICHECK(alloc_storage->op.same_as(mem_alloc_storage_op));
+ auto storage_shape = Downcast<ShapeExpr>(alloc_storage->args[0]);
+ ICHECK_EQ(storage_shape->values.size(), 1);
+ int64_t size = Downcast<IntImm>(storage_shape->values[0])->value;
+ int64_t virtual_device_id =
+ Downcast<IntImm>(Downcast<PrimValue>(alloc_storage->args[1])->value)->value;
+ ICHECK_EQ(virtual_device_id, 0);
+ String storage_scope = Downcast<StringImm>(alloc_storage->args[2])->value;
+ auto [it, _] = storage_records.try_emplace(storage_scope, alloc_plans.size());
+ it->second[plan_id].emplace_back(StorageRecord{size, binding, plan});
+ }
+ }
+
+ // Merge the storage records within each storage scope.
+ // This is achieved by sorting the storage records in descending order of size and then merging
+ // storage allocations from different functions to the largest allocation that can be shared
+ // among all the functions.
+ // This assumes that multiple functions will not run concurrently.
+ std::vector<const VarBindingNode*> merged_allocs;
+ // Merge the storage records within each storage scope.
+ for (auto& [storage_scope, curr_scope_records] : storage_records) {
+ // The number of storages needed for the current storage scope, which is the maximum number of
+ // storage records among all the functions.
+ int num_storages = 0;
+ for (auto& records_of_plan : curr_scope_records) {
+ // Sort descending by size, preserve the original order if the sizes are equal.
+ std::stable_sort(records_of_plan.rbegin(), records_of_plan.rend());
+ num_storages = std::max(num_storages, static_cast<int>(records_of_plan.size()));
+ }
+ // The iterators to scan the storage records of all functions from the left to the right
+ // at the same time.
+ std::vector<int> iters(alloc_plans.size(), 0);
+ for (int i = 0; i < num_storages; i++) {
+ // The storage records from different functions that can be merged to the same storage.
+ std::vector<StorageRecord> to_merge;
+ for (int plan_index = 0; plan_index < static_cast<int>(curr_scope_records.size()); plan_index++) {
+ if (iters[plan_index] < static_cast<int>(curr_scope_records[plan_index].size())) {
+ to_merge.push_back(curr_scope_records[plan_index][iters[plan_index]++]);
+ }
+ }
+ const StorageRecord& largest_storage =
+ *std::max_element(to_merge.begin(), to_merge.end(),
+ [](const auto& lhs, const auto& rhs) { return lhs < rhs; });
+ // Merge the records to the largest allocation by updating the index of the output element
+ // to that of the new allocation function.
+ int storage_index = static_cast<int>(merged_allocs.size());
+ for (const StorageRecord& rec : to_merge) {
+ auto* plan = rec.src;
+ plan->outputs.at(rec.binding->var.get()) = storage_index;
+ }
+ merged_allocs.push_back(largest_storage.binding);
+ }
+ }
+ // Create the new allocation function for the merged allocations.
+ FuncBuilder builder;
+ for (const auto* binding : merged_allocs) {
+ builder.AddBinding(binding);
+ builder.MarkOutput(binding->var.get());
+ }
+ return builder.Build();
+}
+
/*! \brief The rewriter for CUDA graph */
class CUDAGraphRewriter : public ExprMutator {
public:
explicit CUDAGraphRewriter(const IRModule& mod) : ExprMutator(mod) {}
IRModule Rewrite() {
- CUDAGraphRewritePlanner planner(builder_->GetContextIRModule());
- auto plans = planner.Plan();
- for (const auto& plan : plans) {
- subgraph_launches_[plan.launch_point] = plan;
- }
+ CUDAGraphRewritePlanner planner(builder_->GetContextIRModule(), &arena_);
+ // Collect the target functions for rewriting before any mutation.
+ std::vector<std::pair<GlobalVar, Function>> target_functions;
for (const auto& [gv, func] : builder_->GetContextIRModule()->functions) {
if (func->IsInstance<FunctionNode>()) {
- auto new_func = Downcast<Function>(VisitExpr(func));
- if (!new_func.same_as(func)) {
- builder_->UpdateFunction(gv, new_func);
- }
+ target_functions.emplace_back(gv, Downcast<Function>(func));
+ }
+ }
+
+ auto [alloc_plans, capture_plans] = planner.Plan();
+ if (alloc_plans.size()) {
+ auto global_alloc_func = MergeAllocationPlans(alloc_plans);
+ gv_global_alloc_ = builder_->AddFunction(global_alloc_func, "cuda_graph_alloc");
+ }
+ for (const auto* plan : alloc_plans) {
+ subgraph_launches_[plan->launch_point] = plan;
+ }
+ for (const auto* plan : capture_plans) {
+ subgraph_launches_[plan->launch_point] = plan;
+ }
+
+ for (const auto& [gv, func] : target_functions) {
+ current_func_ = gv;
+ auto new_func = Downcast<Function>(VisitExpr(func));
+ if (!new_func.same_as(func)) {
+ builder_->UpdateFunction(gv, new_func);
}
}
return builder_->GetContextIRModule();
}
- void LaunchSubgraph(const VarBindingNode* op, const LiftedFunctionRewritePlan& plan) {
+ void LaunchSubgraph(const VarBindingNode* op, const LiftedFunctionRewritePlan* plan) {
static const auto& call_builtin_with_ctx_op = Op::Get("relax.call_builtin_with_ctx");
static const auto& builtin_run_or_capture = ExternFunc("vm.builtin.cuda_graph.run_or_capture");
static const auto& builtin_get_cached_alloc = ExternFunc("vm.builtin.cuda_graph.get_cached_alloc");
Expr launch_subgraph;
- auto gv_func =
- builder_->AddFunction(plan.func, plan.is_alloc ? "cuda_graph_alloc" : "cuda_graph_capture");
- if (plan.is_alloc) {
+ if (plan->is_alloc) {
// Storage allocation should be fully static and shouldn't depend on any symbolic variables.
- ICHECK(!plan.propogated_tir_vars.defined());
- ICHECK(plan.inputs.empty());
- launch_subgraph =
- Call(call_builtin_with_ctx_op,
- {builtin_get_cached_alloc,
- Tuple({gv_func, PrimValue(IntImm(DataType::Int(64), index_alloc_++))})},
- Attrs(), {plan.func->ret_struct_info});
+ ICHECK(!plan->propogated_tir_vars.defined());
+ ICHECK(plan->inputs.empty());
+ auto gv_alloc = gv_global_alloc_.value();
+ auto ret_struct_info = Downcast<FuncStructInfo>(gv_alloc->struct_info_.value())->ret;
+ launch_subgraph = Call(
+ call_builtin_with_ctx_op,
+ {builtin_get_cached_alloc, Tuple({gv_alloc, PrimValue(IntImm(DataType::Int(64), 0))})},
+ Attrs(), {ret_struct_info});
} else {
- StructInfo call_sinfo = plan.func->ret_struct_info;
+ auto gv_func = builder_->AddFunction(
+ plan->func, current_func_.value()->name_hint + "_cuda_graph_capture");
+ StructInfo call_sinfo = plan->func->ret_struct_info;
// Arguments of the lifted function
Array<Expr> args;
- for (const auto& arg : plan.inputs) {
+ for (const auto& arg : plan->inputs) {
args.push_back(VisitExpr_(arg));
}
- if (plan.propogated_tir_vars.defined()) {
- ShapeExpr propogated_tir_vars = plan.propogated_tir_vars.value();
+ if (plan->propogated_tir_vars.defined()) {
+ ShapeExpr propogated_tir_vars = plan->propogated_tir_vars.value();
args.push_back(propogated_tir_vars);
// The ret_struct_info of the lifted function can contain symbolic variables. We need to
// bind the symbolic parameters to the actual values.
- const auto& shape_expr = plan.func->params.back();
+ const auto& shape_expr = plan->func->params.back();
auto symbolic_params = Downcast<ShapeStructInfo>(shape_expr->struct_info_.value())->values.value();
Map<tir::Var, PrimExpr> tir_var_remap;
@@ -599,25 +794,23 @@ class CUDAGraphRewriter : public ExprMutator {
// Arguments of builtin_run_or_capture
Array<Expr> tuple_arg_fields{gv_func, Tuple(args), PrimValue(IntImm(DataType::Int(64), index_capture_++))};
- if (plan.propogated_tir_vars.defined()) {
+ if (plan->propogated_tir_vars.defined()) {
// The shape expr is explicitly passed twice, one as the last argument of the lifted
// function, one as the last argument of builtin_run_or_capture as the cache key. Explicitly
// passing it twice simplifies the handling during the capture phase.
- tuple_arg_fields.push_back(plan.propogated_tir_vars.value());
+ tuple_arg_fields.push_back(plan->propogated_tir_vars.value());
}
launch_subgraph =
Call(call_builtin_with_ctx_op, {builtin_run_or_capture, Tuple(tuple_arg_fields)}, Attrs(),
{call_sinfo});
}
Expr ret_value = builder_->Emit(launch_subgraph);
- for (int i = 0; i < static_cast<int>(plan.outputs.size()); ++i) {
- // The unpacked result is saved in the var_redef_. It will be emitted when 1) the var
- // definition in the original IR is visited, or 2) the var is used as an input to another
- // lifted function, whichever comes first.
- var_redef_[plan.outputs[i]] = TupleGetItem(ret_value, i);
+ for (const auto& [var, tuple_index] : plan->outputs) {
+ var_redef_[var] = TupleGetItem(ret_value, tuple_index);
}
-
- lifted_bindings_.insert(plan.lifted_bindings.begin(), plan.lifted_bindings.end());
+ std::transform(plan->lifted_bindings.begin(), plan->lifted_bindings.end(),
+ std::inserter(lifted_binding_vars_, lifted_binding_vars_.end()),
+ [](const BindingNode* binding) { return binding->var.get(); });
}
void VisitBinding_(const VarBindingNode* op) final {
@@ -629,7 +822,7 @@ class CUDAGraphRewriter : public ExprMutator {
EmitRedef(op->var.get(), it->second);
return;
}
- if (lifted_bindings_.count(op->var.get())) {
+ if (lifted_binding_vars_.count(op->var.get())) {
// The binding is lifted to the subgraph and will be removed from the original function.
return;
}
@@ -654,11 +847,14 @@ class CUDAGraphRewriter : public ExprMutator {
return new_var;
}
- std::unordered_map<const VarNode*, LiftedFunctionRewritePlan> subgraph_launches_;
+ std::unordered_map<const VarNode*, const LiftedFunctionRewritePlan*> subgraph_launches_;
std::unordered_map<const VarNode*, Expr> var_redef_;
- std::unordered_set<const VarNode*> lifted_bindings_;
+ std::unordered_set<const VarNode*> lifted_binding_vars_;
int index_alloc_ = 0;
int index_capture_ = 0;
+ support::Arena arena_;
+ Optional<GlobalVar> gv_global_alloc_ = NullOpt;
+ Optional<GlobalVar> current_func_ = NullOpt;
};
IRModule RewriteCUDAGraph(IRModule mod) {
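The tests below drive the pass through its transform wrapper. A minimal sketch of applying it, where Before stands in for any IRModule with static storage allocations such as the test modules below:

    import tvm
    from tvm import relax

    # After this change, the rewrite emits a single shared cuda_graph_alloc
    # function plus one <function>_cuda_graph_capture function per rewritten
    # Relax function.
    after = relax.transform.RewriteCUDAGraph()(Before)
    after.show()  # inspect the rewritten module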
diff --git a/tests/python/relax/test_transform_rewrite_cuda_graph.py b/tests/python/relax/test_transform_rewrite_cuda_graph.py
index 43b26f110f..9db285fea6 100644
--- a/tests/python/relax/test_transform_rewrite_cuda_graph.py
+++ b/tests/python/relax/test_transform_rewrite_cuda_graph.py
@@ -107,7 +107,7 @@ def test_rewrite_cuda_graph():
return gv
@R.function(private=True)
- def cuda_graph_capture(alloc: R.Tensor((2, 4), dtype="float32"), alloc1: R.Tensor((2, 4), dtype="float32"), storage: R.Object, storage2: R.Object) -> R.Tuple(R.Tensor((2, 4), dtype="float32")):
+ def main_cuda_graph_capture(alloc: R.Tensor((2, 4), dtype="float32"), alloc1: R.Tensor((2, 4), dtype="float32"), storage: R.Object, storage2: R.Object) -> R.Tuple(R.Tensor((2, 4), dtype="float32")):
R.func_attr({"relax.force_pure": True})
cls = Expected
_2: R.Tuple = cls.exp(alloc, alloc1)
@@ -133,7 +133,7 @@ def test_rewrite_cuda_graph():
storage1: R.Object = gv[1]
alloc1: R.Tensor((2, 4), dtype="float32") = R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([2, 4]), R.dtype("float32"))
storage2: R.Object = gv[2]
- gv1: R.Tuple(R.Tensor((2, 4), dtype="float32")) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.run_or_capture", (cls.cuda_graph_capture, (alloc, alloc1, storage, storage2), R.prim_value(0)), sinfo_args=(R.Tuple(R.Tensor((2, 4), dtype="float32")),))
+ gv1: R.Tuple(R.Tensor((2, 4), dtype="float32")) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.run_or_capture", (cls.main_cuda_graph_capture, (alloc, alloc1, storage, storage2), R.prim_value(0)), sinfo_args=(R.Tuple(R.Tensor((2, 4), dtype="float32")),))
alloc3: R.Tensor((2, 4), dtype="float32") = gv1[0]
alloc4: R.Tensor((2, 4), dtype="float32") = R.builtin.alloc_tensor(R.shape([2, 4]), R.dtype("float32"), R.prim_value(0))
_6: R.Tuple = cls.exp(alloc3, alloc4)
@@ -191,7 +191,7 @@ def test_tuple():
_5: R.Tuple = R.memory.kill_tensor(alloc2)
_6: R.Tuple = R.memory.kill_storage(storage)
_7: R.Tuple = R.memory.kill_storage(storage1)
- return alloc2
+ return alloc3
@I.ir_module
class Expected:
@@ -217,7 +217,7 @@ def test_tuple():
return gv
@R.function(private=True)
- def cuda_graph_capture(alloc: R.Tensor((2, 4), dtype="float32"), alloc1: R.Tensor((2, 4), dtype="float32"), storage: R.Object) -> R.Tuple(R.Tensor((2, 4), dtype="float32")):
+ def main_cuda_graph_capture(alloc: R.Tensor((2, 4), dtype="float32"), alloc1: R.Tensor((2, 4), dtype="float32"), storage: R.Object) -> R.Tuple(R.Tensor((2, 4), dtype="float32")):
R.func_attr({"relax.force_pure": True})
cls = Expected
_: R.Tuple = cls.exp(alloc, alloc1)
@@ -242,14 +242,14 @@ def test_tuple():
_: R.Tuple = cls.exp(x, alloc)
storage1: R.Object = gv[1]
alloc1: R.Tensor((2, 4), dtype="float32") = R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([2, 4]), R.dtype("float32"))
- gv1: R.Tuple(R.Tensor((2, 4), dtype="float32")) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.run_or_capture", (cls.cuda_graph_capture, (alloc, alloc1, storage), R.prim_value(0)), sinfo_args=(R.Tuple(R.Tensor((2, 4), dtype="float32")),))
+ gv1: R.Tuple(R.Tensor((2, 4), dtype="float32")) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.run_or_capture", (cls.main_cuda_graph_capture, (alloc, alloc1, storage), R.prim_value(0)), sinfo_args=(R.Tuple(R.Tensor((2, 4), dtype="float32")),))
alloc2: R.Tensor((2, 4), dtype="float32") = gv1[0]
alloc3: R.Tensor((2, 4), dtype="float32") = R.builtin.alloc_tensor(R.shape([2, 4]), R.dtype("float32"), R.prim_value(0))
_4: R.Tuple = cls.exp(alloc2, alloc3)
_5: R.Tuple = R.memory.kill_tensor(alloc2)
_6: R.Tuple = R.memory.kill_storage(storage)
_7: R.Tuple = R.memory.kill_storage(storage1)
- return alloc2
+ return alloc3
# fmt: on
after = relax.transform.RewriteCUDAGraph()(Before)
@@ -318,7 +318,7 @@ def test_vm_builtin():
return gv
@R.function(private=True)
- def cuda_graph_capture(alloc: R.Tensor((2, 4), dtype="float32"), alloc1: R.Tensor((2, 4), dtype="float32"), storage: R.Object) -> R.Tuple(R.Tensor((2, 4), dtype="float32"), R.Tensor((2, 4), dtype="float32")):
+ def main_cuda_graph_capture(alloc: R.Tensor((2, 4), dtype="float32"), alloc1: R.Tensor((2, 4), dtype="float32"), storage: R.Object) -> R.Tuple(R.Tensor((2, 4), dtype="float32"), R.Tensor((2, 4), dtype="float32")):
R.func_attr({"relax.force_pure": True})
cls = Expected
_2: R.Tuple = cls.exp(alloc, alloc1)
@@ -338,7 +338,7 @@ def test_vm_builtin():
_1: R.Tuple = cls.exp(x, alloc)
storage1: R.Object = gv[1]
alloc1: R.Tensor((2, 4), dtype="float32") = R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([2, 4]), R.dtype("float32"))
- gv1: R.Tuple(R.Tensor((2, 4), dtype="float32"), R.Tensor((2, 4), dtype="float32")) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.run_or_capture", (cls.cuda_graph_capture, (alloc, alloc1, storage), R.prim_value(0)), sinfo_args=(R.Tuple(R.Tensor((2, 4), dtype="float32"), R.Tensor((2, 4), dtype="float32")),))
+ gv1: R.Tuple(R.Tensor((2, 4), dtype="float32"), R.Tensor((2, 4), dtype="float32")) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.run_or_capture", (cls.main_cuda_graph_capture, (alloc, alloc1, storage), R.prim_value(0)), sinfo_args=(R.Tuple(R.Tensor((2, 4), dtype="float32"), R.Tensor((2, 4), dtype="float32")),))
alloc2: R.Tensor((2, 4), dtype="float32") = gv1[1]
lv: R.Tensor((2, 4), dtype="float32") = gv1[0]
_4: R.Tuple = R.call_packed("vm.builtin.dummy", (x, lv), sinfo_args=(R.Tuple,))
@@ -528,7 +528,7 @@ def test_capture_fixed_inputs():
return gv
@R.function(private=True)
- def cuda_graph_capture(
+ def main_cuda_graph_capture(
lv: R.Tensor((16, 32, 32, 16), dtype="float16"),
lv1: R.Tensor((16, 3, 3, 16), dtype="float16"),
alloc1: R.Tensor((16, 32, 32, 16), dtype="float16"),
@@ -635,7 +635,7 @@ def test_capture_fixed_inputs():
) = R.call_builtin_with_ctx(
"vm.builtin.cuda_graph.run_or_capture",
(
- cls.cuda_graph_capture,
+ cls.main_cuda_graph_capture,
(lv_1, lv1, alloc1, alloc, params, storage),
R.prim_value(0),
),
@@ -728,7 +728,7 @@ def test_static_args():
return gv
@R.function(private=True)
- def cuda_graph_capture(alloc0: R.Tensor((8,), dtype="float32")) -> R.Tuple:
+ def main_cuda_graph_capture(alloc0: R.Tensor((8,), dtype="float32")) -> R.Tuple:
R.func_attr({"relax.force_pure": True})
_: R.Object = R.call_packed("dummy_func", alloc0, R.dtype("float32"), R.str("string"))
gv: R.Tuple = R.tuple()
@@ -748,7 +748,7 @@ def test_static_args():
)
gv1: R.Tuple = R.call_builtin_with_ctx(
"vm.builtin.cuda_graph.run_or_capture",
- (cls.cuda_graph_capture, (alloc0,), R.prim_value(0)),
+ (cls.main_cuda_graph_capture, (alloc0,), R.prim_value(0)),
sinfo_args=(R.Tuple,),
)
return R.tuple()
@@ -822,7 +822,7 @@ def test_dynamic_capture():
return gv
@R.function(private=True)
- def cuda_graph_capture(
+ def main_cuda_graph_capture(
alloc1: R.Tensor(("m",), dtype="float32"),
alloc2: R.Tensor(("m",), dtype="float32"),
shape_expr: R.Shape(["m"]),
@@ -858,7 +858,7 @@ def test_dynamic_capture():
R.call_builtin_with_ctx(
"vm.builtin.cuda_graph.run_or_capture",
(
- cls.cuda_graph_capture,
+ cls.main_cuda_graph_capture,
(alloc1, alloc2, R.shape([m])),
R.prim_value(0),
R.shape([m]),
@@ -875,5 +875,218 @@ def test_dynamic_capture():
tvm.ir.assert_structural_equal(mod, Expected)
+class TestMergeAllocFuncs(BaseCompare):
+ @I.ir_module
+ class Before:
+ @R.function
+ def func1():
+ R.func_attr({"relax.force_pure": True})
+ storage1 = R.memory.alloc_storage(R.shape([128]), 0, "global", "float32")
+ storage2 = R.memory.alloc_storage(R.shape([256]), 0, "global", "float32")
+ storage3 = R.memory.alloc_storage(R.shape([512]), 0, "ipc_memory", "float32")
+ alloc1 = R.memory.alloc_tensor(storage1, 0, R.shape([128]), "float32")
+ alloc2 = R.memory.alloc_tensor(storage2, 0, R.shape([256]), "float32")
+ alloc3 = R.memory.alloc_tensor(storage3, 0, R.shape([512]), "float32")
+ R.call_packed("dummy", alloc1, alloc2, alloc3, sinfo_args=(R.Tuple,))
+ return R.tuple()
+
+ @R.function
+ def func2():
+ R.func_attr({"relax.force_pure": True})
+ storage1 = R.memory.alloc_storage(R.shape([192]), 0, "global", "float32")
+ storage2 = R.memory.alloc_storage(R.shape([64]), 0, "global", "float32")
+ storage3 = R.memory.alloc_storage(R.shape([1024]), 0, "ipc_memory", "float32")
+ storage4 = R.memory.alloc_storage(R.shape([512]), 0, "global", "float32")
+ alloc1 = R.memory.alloc_tensor(storage1, 0, R.shape([192]), "float32")
+ alloc2 = R.memory.alloc_tensor(storage2, 0, R.shape([64]), "float32")
+ alloc3 = R.memory.alloc_tensor(storage3, 0, R.shape([1024]), "float32")
+ alloc4 = R.memory.alloc_tensor(storage4, 0, R.shape([512]), "float32")
+ R.call_packed("dummy", alloc1, alloc2, alloc3, alloc4, sinfo_args=(R.Tuple,))
+ return R.tuple()
+
+ @I.ir_module
+ class Expected:
+ @R.function(private=True)
+ def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object, R.Object, R.Object):
+ R.func_attr({"relax.force_pure": True})
+ storage4: R.Object = R.memory.alloc_storage(
+ R.shape([512]), R.prim_value(0), R.str("global"), R.dtype("float32")
+ )
+ storage1: R.Object = R.memory.alloc_storage(
+ R.shape([192]), R.prim_value(0), R.str("global"), R.dtype("float32")
+ )
+ storage2: R.Object = R.memory.alloc_storage(
+ R.shape([64]), R.prim_value(0), R.str("global"), R.dtype("float32")
+ )
+ storage3: R.Object = R.memory.alloc_storage(
+ R.shape([1024]), R.prim_value(0), R.str("ipc_memory"), R.dtype("float32")
+ )
+ gv: R.Tuple(R.Object, R.Object, R.Object, R.Object) = (
+ storage4,
+ storage1,
+ storage2,
+ storage3,
+ )
+ return gv
+
+ @R.function
+ def func1() -> R.Tuple:
+ R.func_attr({"relax.force_pure": True})
+ cls = Expected
+ gv: R.Tuple(R.Object, R.Object, R.Object, R.Object) = R.call_builtin_with_ctx(
+ "vm.builtin.cuda_graph.get_cached_alloc",
+ (cls.cuda_graph_alloc, R.prim_value(0)),
+ sinfo_args=(R.Tuple(R.Object, R.Object, R.Object, R.Object),),
+ )
+ storage1: R.Object = gv[1]
+ storage2: R.Object = gv[0]
+ storage3: R.Object = gv[3]
+ alloc1: R.Tensor((128,), dtype="float32") = R.memory.alloc_tensor(
+ storage1, R.prim_value(0), R.shape([128]), R.dtype("float32")
+ )
+ alloc2: R.Tensor((256,), dtype="float32") = R.memory.alloc_tensor(
+ storage2, R.prim_value(0), R.shape([256]), R.dtype("float32")
+ )
+ alloc3: R.Tensor((512,), dtype="float32") = R.memory.alloc_tensor(
+ storage3, R.prim_value(0), R.shape([512]), R.dtype("float32")
+ )
+ R.call_builtin_with_ctx(
+ "vm.builtin.cuda_graph.run_or_capture",
+ (cls.func1_cuda_graph_capture, (alloc1, alloc2, alloc3), R.prim_value(0)),
+ sinfo_args=(R.Tuple,),
+ )
+ return R.tuple()
+
+ @R.function(private=True)
+ def func1_cuda_graph_capture(
+ alloc1: R.Tensor((128,), dtype="float32"),
+ alloc2: R.Tensor((256,), dtype="float32"),
+ alloc3: R.Tensor((512,), dtype="float32"),
+ ) -> R.Tuple:
+ R.func_attr({"relax.force_pure": True})
+ R.call_packed("dummy", alloc1, alloc2, alloc3,
sinfo_args=(R.Tuple,))
+ R.tuple()
+ return R.tuple()
+
+ @R.function
+ def func2() -> R.Tuple:
+ R.func_attr({"relax.force_pure": True})
+ cls = Expected
+ gv2: R.Tuple(R.Object, R.Object, R.Object, R.Object) = R.call_builtin_with_ctx(
+ "vm.builtin.cuda_graph.get_cached_alloc",
+ (cls.cuda_graph_alloc, R.prim_value(0)),
+ sinfo_args=(R.Tuple(R.Object, R.Object, R.Object, R.Object),),
+ )
+ storage11: R.Object = gv2[1]
+ storage21: R.Object = gv2[2]
+ storage31: R.Object = gv2[3]
+ storage4: R.Object = gv2[0]
+ alloc1: R.Tensor((192,), dtype="float32") = R.memory.alloc_tensor(
+ storage11, R.prim_value(0), R.shape([192]), R.dtype("float32")
+ )
+ alloc2: R.Tensor((64,), dtype="float32") = R.memory.alloc_tensor(
+ storage21, R.prim_value(0), R.shape([64]), R.dtype("float32")
+ )
+ alloc3: R.Tensor((1024,), dtype="float32") = R.memory.alloc_tensor(
+ storage31, R.prim_value(0), R.shape([1024]), R.dtype("float32")
+ )
+ alloc4: R.Tensor((512,), dtype="float32") = R.memory.alloc_tensor(
+ storage4, R.prim_value(0), R.shape([512]), R.dtype("float32")
+ )
+ R.call_builtin_with_ctx(
+ "vm.builtin.cuda_graph.run_or_capture",
+ (cls.func2_cuda_graph_capture, (alloc1, alloc2, alloc3, alloc4), R.prim_value(1)),
+ sinfo_args=(R.Tuple,),
+ )
+ return R.tuple()
+
+ @R.function(private=True)
+ def func2_cuda_graph_capture(
+ alloc1: R.Tensor((192,), dtype="float32"),
+ alloc2: R.Tensor((64,), dtype="float32"),
+ alloc3: R.Tensor((1024,), dtype="float32"),
+ alloc4: R.Tensor((512,), dtype="float32"),
+ ) -> R.Tuple:
+ R.func_attr({"relax.force_pure": True})
+ R.call_packed("dummy", alloc1, alloc2, alloc3, alloc4,
sinfo_args=(R.Tuple,))
+ R.tuple()
+ return R.tuple()
+
+
+class TestDisableCaptureOutput(BaseCompare):
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(x: R.Tensor((8,), "float32")) -> R.Tuple(R.Tensor((8,),
"float32")):
+ R.func_attr({"relax.force_pure": True})
+ storage1 = R.memory.alloc_storage(R.shape([8]), 0, "global",
"float32")
+ alloc1 = R.memory.alloc_tensor(storage1, 0, R.shape([8]),
"float32")
+ _ = R.call_packed("dummy", x, alloc1, sinfo_args=(R.Tuple,))
+ storage2 = R.memory.alloc_storage(R.shape([8]), 0, "global",
"float32")
+ alloc2 = R.memory.alloc_tensor(storage2, 0, R.shape([8]),
"float32")
+ _1 = R.call_packed("dummy", alloc1, alloc2, sinfo_args=(R.Tuple,))
+ storage3 = R.memory.alloc_storage(R.shape([8]), 0, "global",
"float32")
+ alloc3 = R.memory.alloc_tensor(storage3, 0, R.shape([8]),
"float32")
+ _2 = R.call_packed("dummy", alloc2, alloc3, sinfo_args=(R.Tuple,))
+ gv = (alloc3,)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @R.function(private=True)
+ def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object):
+ R.func_attr({"relax.force_pure": True})
+ storage1: R.Object = R.memory.alloc_storage(
+ R.shape([8]), R.prim_value(0), R.str("global"),
R.dtype("float32")
+ )
+ storage2: R.Object = R.memory.alloc_storage(
+ R.shape([8]), R.prim_value(0), R.str("global"),
R.dtype("float32")
+ )
+ gv: R.Tuple(R.Object, R.Object) = storage1, storage2
+ return gv
+
+ @R.function(private=True)
+ def main_cuda_graph_capture(
+ alloc1: R.Tensor((8,), dtype="float32"), alloc2: R.Tensor((8,),
dtype="float32")
+ ) -> R.Tuple:
+ R.func_attr({"relax.force_pure": True})
+ R.call_packed("dummy", alloc1, alloc2, sinfo_args=(R.Tuple,))
+ R.tuple()
+ return R.tuple()
+
+ @R.function
+ def main(x: R.Tensor((8,), dtype="float32")) -> R.Tuple(R.Tensor((8,),
dtype="float32")):
+ R.func_attr({"relax.force_pure": True})
+ cls = Expected
+ gv: R.Tuple(R.Object, R.Object) = R.call_builtin_with_ctx(
+ "vm.builtin.cuda_graph.get_cached_alloc",
+ (cls.cuda_graph_alloc, R.prim_value(0)),
+ sinfo_args=(R.Tuple(R.Object, R.Object),),
+ )
+ storage1: R.Object = gv[0]
+ alloc1: R.Tensor((8,), dtype="float32") = R.memory.alloc_tensor(
+ storage1, R.prim_value(0), R.shape([8]), R.dtype("float32")
+ )
+ R.call_packed("dummy", x, alloc1, sinfo_args=(R.Tuple,))
+ storage2: R.Object = gv[1]
+ alloc2: R.Tensor((8,), dtype="float32") = R.memory.alloc_tensor(
+ storage2, R.prim_value(0), R.shape([8]), R.dtype("float32")
+ )
+ R.call_builtin_with_ctx(
+ "vm.builtin.cuda_graph.run_or_capture",
+ (cls.main_cuda_graph_capture, (alloc1, alloc2), R.prim_value(0)),
+ sinfo_args=(R.Tuple,),
+ )
+ storage3: R.Object = R.memory.alloc_storage(
+ R.shape([8]), R.prim_value(0), R.str("global"),
R.dtype("float32")
+ )
+ alloc3: R.Tensor((8,), dtype="float32") = R.memory.alloc_tensor(
+ storage3, R.prim_value(0), R.shape([8]), R.dtype("float32")
+ )
+ R.call_packed("dummy", alloc2, alloc3, sinfo_args=(R.Tuple,))
+ gv = (alloc3,)
+ return gv
+
+
if __name__ == "__main__":
tvm.testing.main()