zhuwenxi commented on a change in pull request #7619:
URL: https://github.com/apache/tvm/pull/7619#discussion_r603074923



##########
File path: src/tir/transforms/lower_tvm_builtin.cc
##########
@@ -227,34 +271,36 @@ class BuiltinLower : public StmtExprMutator {
       if (t != api_type) {
         arg = Cast(api_type, arg);
       }
-      prep_seq_.emplace_back(TVMStructSet(stack_value_, static_cast<int>(arg_stack_begin + i - 1),
+      prep_seq_.emplace_back(TVMStructSet(scope.stack_value_,
+                                          static_cast<int>(arg_stack_begin + i - 1),
                                           builtin::kTVMValueContent, arg));
       int arg_tcode = api_type.code();
       if (api_type.is_handle() && arg.as<StringImmNode>()) {
         arg_tcode = kTVMStr;
       }
       if (IsArrayHandle(arg)) arg_tcode = kTVMDLTensorHandle;
       prep_seq_.emplace_back(
-          Store(stack_tcode_, ConstInt32(arg_tcode), stack_index, const_true(1)));
+          Store(scope.stack_tcode_, ConstInt32(arg_tcode), stack_index, const_true(1)));
     }
     // UPDATE stack value
-    max_arg_stack_ = std::max(run_arg_stack_, max_arg_stack_);
-    max_shape_stack_ = std::max(run_shape_stack_, max_shape_stack_);
-    max_array_stack_ = std::max(run_array_stack_, max_array_stack_);
-    run_shape_stack_ = restore_shape_stack;
-    run_array_stack_ = restore_array_stack;
-    run_arg_stack_ = arg_stack_begin;
-    Array<PrimExpr> packed_args = {op->args[0], stack_value_, stack_tcode_,
+    scope.max_arg_stack_ = std::max(scope.run_arg_stack_, scope.max_arg_stack_);
+    scope.max_shape_stack_ = std::max(scope.run_shape_stack_, scope.max_shape_stack_);
+    scope.max_array_stack_ = std::max(scope.run_array_stack_, scope.max_array_stack_);
+    scope.run_shape_stack_ = restore_shape_stack;
+    scope.run_array_stack_ = restore_array_stack;
+    scope.run_arg_stack_ = arg_stack_begin;
+    Array<PrimExpr> packed_args = {op->args[0], scope.stack_value_, scope.stack_tcode_,
                                    ConstInt32(arg_stack_begin),
                                    ConstInt32(arg_stack_begin + op->args.size() - 1)};
     return Call(DataType::Int(32), builtin::tvm_call_packed_lowered(), packed_args);
   }
 
   PrimExpr MakeCallTracePacked(const CallNode* op) {
-    int64_t restore_shape_stack = run_shape_stack_;
-    size_t restore_array_stack = run_array_stack_;
-    size_t arg_stack_begin = run_arg_stack_;
-    run_arg_stack_ += op->args.size();
+    auto& scope = alloca_scope_.back();

Review comment:
       Done.

##########
File path: src/tir/transforms/lower_tvm_builtin.cc
##########
@@ -158,64 +199,67 @@ class BuiltinLower : public StmtExprMutator {
   // call shape
   PrimExpr MakeShape(const CallNode* op) {
     // if args.size() == 0, it represents a scalar shape ()
-    if (run_shape_stack_ == -1) {
-      run_shape_stack_ = 0;
+    auto& scope = alloca_scope_.back();
+    if (scope.run_shape_stack_ == -1) {
+      scope.run_shape_stack_ = 0;
     }
-    int64_t stack_begin = run_shape_stack_;
-    run_shape_stack_ += op->args.size();
+    int64_t stack_begin = scope.run_shape_stack_;
+    scope.run_shape_stack_ += op->args.size();
     PrimExpr expr = StmtExprMutator::VisitExpr_(op);
     op = expr.as<CallNode>();
     // no need to perform any store for a scalar shape
     for (size_t i = 0; i < op->args.size(); ++i) {
-      prep_seq_.emplace_back(Store(stack_shape_, cast(DataType::Int(64), op->args[i]),
+      prep_seq_.emplace_back(Store(scope.stack_shape_, cast(DataType::Int(64), op->args[i]),
                                    ConstInt32(stack_begin + i), const_true(1)));
     }
-    return AddressOffset(stack_shape_, DataType::Int(64), stack_begin);
+    return AddressOffset(scope.stack_shape_, DataType::Int(64), stack_begin);
   }
   // make array
   PrimExpr MakeArray(const CallNode* op) {
-    size_t idx = run_array_stack_;
-    run_array_stack_ += 1;
+    auto& scope = alloca_scope_.back();

Review comment:
       Done.

##########
File path: src/tir/transforms/lower_tvm_builtin.cc
##########
@@ -48,30 +48,54 @@ inline PrimExpr StackAlloca(std::string type, size_t num) {
 // These information are needed during codegen.
 class BuiltinLower : public StmtExprMutator {
  public:
-  Stmt Build(Stmt stmt) {
-    stack_shape_ = Var("stack_shape", DataType::Handle());
-    stack_array_ = Var("stack_array", DataType::Handle());
-    stack_value_ = Var("stack_value", DataType::Handle());
-    stack_tcode_ = Var("stack_tcode", DataType::Handle());
+  // Record stack frame for existing scope.
+  struct AllocaScope {
+    Var stack_shape_ = Var("stack_shape", DataType::Handle());

Review comment:
       Fixed.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to