tqchen commented on code in PR #16815:
URL: https://github.com/apache/tvm/pull/16815#discussion_r1544103179
##########
src/relax/transform/rewrite_cuda_graph.cc:
##########
@@ -343,48 +397,83 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
}
bool IsStatic(const PrimExpr& expr,
- [[maybe_unused]] std::vector<const VarNode*>* vars_collector =
nullptr) {
- return expr->IsInstance<tir::IntImmNode>() ||
expr->IsInstance<tir::FloatImmNode>();
+ [[maybe_unused]] std::vector<const VarNode*>* vars_collector =
nullptr,
+ std::vector<const tir::VarNode*>* tir_vars_collector =
nullptr) {
+ bool is_static = true;
+ tir::PostOrderVisit(expr, [&](const ObjectRef& e) {
+ if (auto var = e.as<tir::VarNode>()) {
+ if (!capture_symbolic_vars_.count(var->name_hint)) {
+ is_static = false;
+ return;
+ }
+ if (tir_vars_collector != nullptr) {
+ tir_vars_collector->push_back(var);
+ }
+ }
+ });
+ return is_static;
}
- bool IsStatic(const Expr& expr, std::vector<const VarNode*>* vars_collector
= nullptr) {
+ bool IsStatic(const Expr& expr, std::vector<const VarNode*>* vars_collector
= nullptr,
+ std::vector<const tir::VarNode*>* tir_vars_collector =
nullptr) {
if (expr->IsInstance<ConstantNode>() ||
expr->IsInstance<DataTypeImmNode>() ||
- expr->IsInstance<StringImmNode>()) {
+ expr->IsInstance<StringImmNode>() ||
expr->IsInstance<GlobalVarNode>()) {
return true;
}
if (const auto* prim_value = expr.as<PrimValueNode>()) {
- return IsStatic(prim_value->value, vars_collector);
+ return IsStatic(prim_value->value, vars_collector, tir_vars_collector);
}
if (const auto* var = expr.as<VarNode>()) {
if (vars_collector != nullptr) {
vars_collector->push_back(var);
}
- return static_vars_.count(var);
+ // recursively check the struct info to collect the symbolic TIR vars
+ return static_vars_.count(var) &&
IsStatic(Downcast<StructInfo>(var->struct_info_.value()),
+ vars_collector,
tir_vars_collector);
}
if (const auto* shape = expr.as<ShapeExprNode>()) {
- return IsStatic(shape->values, vars_collector);
+ return IsStatic(shape->values, vars_collector, tir_vars_collector);
}
if (const auto* tuple = expr.as<TupleNode>()) {
- return IsStatic(tuple->fields, vars_collector);
+ return IsStatic(tuple->fields, vars_collector, tir_vars_collector);
}
return false;
}
template <typename T>
- bool IsStatic(const Array<T>& exprs, std::vector<const VarNode*>*
vars_collector = nullptr) {
+ bool IsStatic(const Array<T>& exprs, std::vector<const VarNode*>*
vars_collector = nullptr,
+ std::vector<const tir::VarNode*>* tir_vars_collector =
nullptr) {
bool result = true;
for (const auto& expr : exprs) {
// If vars_collector is provided, we will collect all the vars in the
exprs and we should
// not perform short-circuiting.
- result &= IsStatic(expr, vars_collector);
- if (!vars_collector && !result) {
+ result &= IsStatic(expr, vars_collector, tir_vars_collector);
+ if (vars_collector == nullptr && tir_vars_collector == nullptr &&
!result) {
return false;
}
}
return result;
}
+ bool IsStatic(const StructInfo& sinfo, std::vector<const VarNode*>*
vars_collector = nullptr,
+ std::vector<const tir::VarNode*>* tir_vars_collector =
nullptr) {
Review Comment:
Consider using `tir::Var` (which holds a reference) instead of a raw
`const tir::VarNode*`, to avoid the var being deallocated during the rewrite.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]