areusch commented on a change in pull request #8096:
URL: https://github.com/apache/tvm/pull/8096#discussion_r646828579



##########
File path: src/relay/backend/aot_executor_codegen.cc
##########
@@ -44,52 +45,185 @@ namespace tvm {
 namespace relay {
 namespace backend {
 
+/**
+ * Struct to contain information about the intermediate tensors in the
+ * runner function
+ */
+struct StorageInfo {
+  /*! \brief storage integer identifier of the particular intermediate buffer 
*/
+  int sid;
+  /*! \brief exact size of the temporary */
+  int size_bytes;
+  /*! \brief device type of the intermediate tensor */
+  int dev_type;
+};
+
 using IntegerArray = Array<Integer>;
 using TargetsMap = std::unordered_map<int, Target>;
+using StorageMap = std::unordered_map<Expr, std::vector<StorageInfo>, 
runtime::ObjectPtrHash,
+                                      runtime::ObjectPtrEqual>;
 
-class AotReturnSidVisitor : public ExprVisitor {
+/**
+ * This is an on demand allocator for AOT. A new temporary
+ * (storage allocator identifier) is allocated for each operation.
+ */
+class AOTOnDemandAllocator : public ExprVisitor {
  public:
-  explicit AotReturnSidVisitor(Map<Expr, Array<IntegerArray>> 
storage_device_map)
-      : storage_device_map_{storage_device_map}, return_sid_{-1} {}
+  // run the visitor on a function.
+  void Run(const Function& func) {
+    node_device_map_ = CollectDeviceInfo(func);
 
-  IntegerArray FindReturnSid(Function func) {
-    VisitExpr(func->body);
-    return return_sid_;
+    for (Expr param : func->params) {
+      CreateStorage(param.operator->());
+    }
+
+    GetStorage(func->body);
   }
 
- protected:
-  void AssignReturnSid(Expr e) {
-    auto iter = storage_device_map_.find(e);
-    if (iter != storage_device_map_.end()) {
-      return_sid_ = (*iter).second[0];
+  std::vector<int> GetReturnIds() const { return return_ids_; }
+
+  StorageMap GetStorageMap() const { return storage_device_map_; }
+
+  void VisitExpr_(const ConstantNode* op) final {
+    CreateStorage(op);
+    AssignReturnSid(GetRef<Expr>(op));
+  }
+
+  void VisitExpr_(const CallNode* op) final {
+    // create token for the call node.
+    CreateStorage(op);
+    for (Expr arg : op->args) {
+      GetStorage(arg);
     }
+    AssignReturnSid(GetRef<Expr>(op));
   }
 
-  void VisitExpr_(const ConstantNode* cn) override {
-    ExprVisitor::VisitExpr_(cn);
-    AssignReturnSid(GetRef<Expr>(cn));
+  void VisitExpr_(const VarNode* op) final {
+    ExprVisitor::VisitExpr_(op);
+    AssignReturnSid(GetRef<Expr>(op));
   }
 
-  void VisitExpr_(const VarNode* vn) override {
-    ExprVisitor::VisitExpr_(vn);
-    AssignReturnSid(GetRef<Expr>(vn));
+  void VisitExpr_(const FunctionNode* op) final {
+    // do not recurse into sub function.
   }
 
-  void VisitExpr_(const CallNode* cn) override {
-    ExprVisitor::VisitExpr_(cn);
-    AssignReturnSid(GetRef<Expr>(cn));
+  void VisitExpr_(const GlobalVarNode* op) final {
+    // Do nothing.
   }
 
-  void VisitExpr_(const LetNode* op) override { VisitExpr(op->body); }
+  void VisitExpr_(const OpNode* op) final {
+    // Do nothing.
+  }
+
+  void VisitExpr_(const TupleNode* op) final {
+    std::vector<StorageInfo> field_sids;
+    Expr expr = GetRef<Expr>(op);
+    for (Expr field : op->fields) {
+      auto sid = GetStorage(field);
+      field_sids.insert(field_sids.end(), sid.begin(), sid.end());
+    }
+
+    storage_device_map_[expr] = field_sids;
+    AssignReturnSid(expr);
+  }
 
-  void VisitExpr_(const TupleNode* tn) override {
-    ExprVisitor::VisitExpr_(tn);
-    AssignReturnSid(GetRef<Expr>(tn));
+  void VisitExpr_(const TupleGetItemNode* op) final {
+    Expr expr = GetRef<Expr>(op);
+    const auto& sids = GetStorage(op->tuple);
+    ICHECK_LT(static_cast<size_t>(op->index), sids.size());
+    storage_device_map_[expr] = {sids[op->index]};
+    AssignReturnSid(expr);
   }
 
+  void VisitExpr_(const IfNode* op) final { LOG(FATAL) << "if is not 
supported."; }
+
+  void VisitExpr_(const LetNode* op) final { LOG(FATAL) << "if is not 
supported."; }
+
  private:
-  Map<Expr, Array<IntegerArray>> storage_device_map_;
-  IntegerArray return_sid_;
+  void AssignReturnSid(Expr e) {
+    if (storage_device_map_.find(e) != storage_device_map_.end()) {
+      auto buffers = storage_device_map_[e];
+      std::vector<int> return_ids;
+      for (auto buffer : buffers) {
+        return_ids.push_back(buffer.sid);
+      }
+      return_ids_ = return_ids;
+    }
+  }
+  /*!
+   * \brief ceil(size/word_size) to get number of words.
+   * \param size The original size.
+   * \param word_size The element size.
+   */
+  static size_t DivRoundUp(size_t size, size_t word_size) {
+    return (size + word_size - 1) / word_size;
+  }
+  /*!
+   * \brief Get the memory requirement.
+   * \param prototype The prototype token.
+   * \return The required memory size.
+   */
+  size_t GetMemorySize(const TensorTypeNode* ttype) {

Review comment:
       prefer to include units if possible, e.g. GetMemorySizeBytes or 
GetNumBytes

##########
File path: src/tir/transforms/storage_rewrite.cc
##########
@@ -138,6 +138,34 @@ class LinearAccessPatternFinder final : public 
StmtExprVisitor {
     if (op->op.same_as(builtin::address_of())) {
       const LoadNode* l = op->args[0].as<LoadNode>();
       this->VisitExpr(l->index);
+    } else if (op->op.same_as(builtin::tvm_call_cpacked())) {
+      // Recall that the arguments of a tvm_call_cpacked are passed as
+      // TVMValues. But a TVMValue is only a container, that points to
+      // a real buffer previously allocated. We need to signal that those
+      // buffers need to be live at the same time (i.e., cannot be overridden)
+      Array<PrimExpr> args = op->args;
+      for (auto arg : args) {
+        const VarNode* var = arg.as<VarNode>();
+        if (value_to_alloc_.find(var) != value_to_alloc_.end()) {
+          auto allocs = value_to_alloc_[var];
+          for (const VarNode* alloc : allocs) {
+            VisitExpr_(alloc);
+          }
+        } else {
+          this->VisitExpr(arg);
+        }
+      }
+    } else if (op->op.same_as(builtin::tvm_struct_set())) {

Review comment:
       are these the only two such builtins we need to care about? It seems 
like any access to the data would be affected, no?

##########
File path: src/relay/backend/aot_executor_codegen.cc
##########
@@ -44,52 +45,185 @@ namespace tvm {
 namespace relay {
 namespace backend {
 
+/**
+ * Struct to contain information about the intermediate tensors in the
+ * runner function
+ */
+struct StorageInfo {
+  /*! \brief storage integer identifier of the particular intermediate buffer 
*/
+  int sid;
+  /*! \brief exact size of the temporary */
+  int size_bytes;
+  /*! \brief device type of the intermediate tensor */
+  int dev_type;
+};
+
 using IntegerArray = Array<Integer>;
 using TargetsMap = std::unordered_map<int, Target>;
+using StorageMap = std::unordered_map<Expr, std::vector<StorageInfo>, 
runtime::ObjectPtrHash,
+                                      runtime::ObjectPtrEqual>;
 
-class AotReturnSidVisitor : public ExprVisitor {
+/**
+ * This is an on demand allocator for AOT. A new temporary
+ * (storage allocator identifier) is allocated for each operation.
+ */
+class AOTOnDemandAllocator : public ExprVisitor {
  public:
-  explicit AotReturnSidVisitor(Map<Expr, Array<IntegerArray>> 
storage_device_map)
-      : storage_device_map_{storage_device_map}, return_sid_{-1} {}
+  // run the visitor on a function.
+  void Run(const Function& func) {
+    node_device_map_ = CollectDeviceInfo(func);
 
-  IntegerArray FindReturnSid(Function func) {
-    VisitExpr(func->body);
-    return return_sid_;
+    for (Expr param : func->params) {
+      CreateStorage(param.operator->());
+    }
+
+    GetStorage(func->body);
   }
 
- protected:
-  void AssignReturnSid(Expr e) {
-    auto iter = storage_device_map_.find(e);
-    if (iter != storage_device_map_.end()) {
-      return_sid_ = (*iter).second[0];
+  std::vector<int> GetReturnIds() const { return return_ids_; }
+
+  StorageMap GetStorageMap() const { return storage_device_map_; }
+
+  void VisitExpr_(const ConstantNode* op) final {
+    CreateStorage(op);
+    AssignReturnSid(GetRef<Expr>(op));
+  }
+
+  void VisitExpr_(const CallNode* op) final {
+    // create token for the call node.
+    CreateStorage(op);
+    for (Expr arg : op->args) {
+      GetStorage(arg);
     }
+    AssignReturnSid(GetRef<Expr>(op));
+  }
+
+  void VisitExpr_(const VarNode* op) final {
+    ExprVisitor::VisitExpr_(op);
+    AssignReturnSid(GetRef<Expr>(op));
   }
 
-  void VisitExpr_(const ConstantNode* cn) override {
-    ExprVisitor::VisitExpr_(cn);
-    AssignReturnSid(GetRef<Expr>(cn));
+  void VisitExpr_(const FunctionNode* op) final {
+    // do not recurse into sub function.
   }
 
-  void VisitExpr_(const VarNode* vn) override {
-    ExprVisitor::VisitExpr_(vn);
-    AssignReturnSid(GetRef<Expr>(vn));
+  void VisitExpr_(const GlobalVarNode* op) final {
+    // Do nothing.
   }
 
-  void VisitExpr_(const CallNode* cn) override {
-    ExprVisitor::VisitExpr_(cn);
-    AssignReturnSid(GetRef<Expr>(cn));
+  void VisitExpr_(const OpNode* op) final {
+    // Do nothing.
   }
 
-  void VisitExpr_(const LetNode* op) override { VisitExpr(op->body); }
+  void VisitExpr_(const TupleNode* op) final {
+    std::vector<StorageInfo> field_sids;
+    Expr expr = GetRef<Expr>(op);
+    for (Expr field : op->fields) {
+      auto sid = GetStorage(field);
+      field_sids.insert(field_sids.end(), sid.begin(), sid.end());
+    }
 
-  void VisitExpr_(const TupleNode* tn) override {
-    ExprVisitor::VisitExpr_(tn);
-    AssignReturnSid(GetRef<Expr>(tn));
+    storage_device_map_[expr] = field_sids;
+    AssignReturnSid(expr);
   }
 
+  void VisitExpr_(const TupleGetItemNode* op) final {
+    Expr expr = GetRef<Expr>(op);
+    const auto& sids = GetStorage(op->tuple);
+    ICHECK_LT(static_cast<size_t>(op->index), sids.size());
+    storage_device_map_[expr] = {sids[op->index]};
+    AssignReturnSid(expr);
+  }
+
+  void VisitExpr_(const IfNode* op) final { LOG(FATAL) << "if is not 
supported."; }
+
+  void VisitExpr_(const LetNode* op) final { LOG(FATAL) << "if is not 
supported."; }
+
  private:
-  Map<Expr, Array<IntegerArray>> storage_device_map_;
-  IntegerArray return_sid_;
+  void AssignReturnSid(Expr e) {
+    if (storage_device_map_.find(e) != storage_device_map_.end()) {
+      auto buffers = storage_device_map_[e];
+      std::vector<int> return_ids;
+      for (auto buffer : buffers) {
+        return_ids.push_back(buffer.sid);
+      }
+      return_ids_ = return_ids;
+    }
+  }
+  /*!
+   * \brief ceil(size/word_size) to get number of words.
+   * \param size The original size.
+   * \param word_size The element size.
+   */
+  static size_t DivRoundUp(size_t size, size_t word_size) {
+    return (size + word_size - 1) / word_size;
+  }
+  /*!
+   * \brief Get the memory requirement.
+   * \param prototype The prototype token.
+   * \return The required memory size.
+   */
+  size_t GetMemorySize(const TensorTypeNode* ttype) {
+    ICHECK(ttype != nullptr);
+    size_t size = 1;
+    for (IndexExpr dim : ttype->shape) {
+      const int64_t* pval = tir::as_const_int(dim);
+      ICHECK(pval != nullptr) << "Cannot allocate memory symbolic tensor shape 
" << ttype->shape;
+      ICHECK_GE(*pval, 0) << "Cannot allocate memory for tensor with negative 
shape" << *pval;
+      size *= static_cast<size_t>(pval[0]);
+    }
+    size *= DivRoundUp(ttype->dtype.bits() * ttype->dtype.lanes(), 8);
+    return size;
+  }
+  /*!
+   * \brief Get the necessary storage for the expression.
+   * \param expr The expression.
+   * \return The corresponding token.
+   */
+  std::vector<StorageInfo> GetStorage(const Expr& expr) {
+    this->VisitExpr(expr);
+    auto it = storage_device_map_.find(expr);
+    ICHECK(it != storage_device_map_.end());
+    return it->second;
+  }
+
+  /*!
+   * \brief Create storage for the expression.
+   * \param expr The expression.
+   */
+  void CreateStorage(const ExprNode* op) {
+    std::vector<StorageInfo> buffers;
+    Expr expr = GetRef<Expr>(op);
+    int device_type = node_device_map_.count(GetRef<Expr>(op)) ? 
node_device_map_[expr]->value : 0;
+    if (const auto* tuple_type = op->checked_type().as<TupleTypeNode>()) {
+      for (Type t : tuple_type->fields) {
+        const auto* ttype = t.as<TensorTypeNode>();
+        ICHECK(ttype);
+        StorageInfo buffer;
+        buffer.sid = sid_++;
+        buffer.size_bytes = GetMemorySize(ttype);
+        buffer.dev_type = device_type;
+        buffers.push_back(buffer);
+      }
+    } else {
+      const auto* ttype = op->checked_type().as<TensorTypeNode>();
+      ICHECK(ttype);
+      StorageInfo buffer;
+      buffer.sid = sid_++;
+      buffer.size_bytes = GetMemorySize(ttype);
+      buffer.dev_type = device_type;
+      buffers.push_back(buffer);
+    }
+    storage_device_map_[expr] = buffers;
+  }
+  /*! \brief mapping of expression -> storageInfo*/
+  StorageMap storage_device_map_;
+  /*! \brief mapping of expression -> device type*/
+  Map<Expr, Integer> node_device_map_;
+  /*! \brief current id of the temporary allocated*/
+  int sid_{0};

Review comment:
       maybe name this `next_available_sid_`?

##########
File path: src/tir/transforms/storage_rewrite.cc
##########
@@ -138,6 +138,34 @@ class LinearAccessPatternFinder final : public 
StmtExprVisitor {
     if (op->op.same_as(builtin::address_of())) {
       const LoadNode* l = op->args[0].as<LoadNode>();
       this->VisitExpr(l->index);
+    } else if (op->op.same_as(builtin::tvm_call_cpacked())) {
+      // Recall that the arguments of a tvm_call_cpacked are passed as
+      // TVMValues. But a TVMValue is only a container, that points to
+      // a real buffer previously allocated. We need to signal that those
+      // buffers need to be live at the same time (i.e., cannot be overridden)

Review comment:
       nit: change this to say something like: (i.e., cannot be overwritten 
during the function call).

##########
File path: tests/python/relay/aot/test_crt_aot.py
##########
@@ -364,5 +364,27 @@ def test_byoc_utvm(use_calculated_workspaces):
     compile_and_run(mod, input_list, output_list, use_calculated_workspaces)
 
 
+def test_quant_mobilenet_tfl():

Review comment:
       can you also do this for the test case below?

##########
File path: src/relay/backend/aot_executor_codegen.cc
##########
@@ -44,52 +45,185 @@ namespace tvm {
 namespace relay {
 namespace backend {
 
+/**
+ * Struct to contain information about the intermediate tensors in the
+ * runner function
+ */
+struct StorageInfo {
+  /*! \brief storage integer identifier of the particular intermediate buffer 
*/
+  int sid;
+  /*! \brief exact size of the temporary */
+  int size_bytes;
+  /*! \brief device type of the intermediate tensor */
+  int dev_type;
+};
+
 using IntegerArray = Array<Integer>;
 using TargetsMap = std::unordered_map<int, Target>;
+using StorageMap = std::unordered_map<Expr, std::vector<StorageInfo>, 
runtime::ObjectPtrHash,
+                                      runtime::ObjectPtrEqual>;
 
-class AotReturnSidVisitor : public ExprVisitor {
+/**
+ * This is an on demand allocator for AOT. A new temporary
+ * (storage allocator identifier) is allocated for each operation.
+ */
+class AOTOnDemandAllocator : public ExprVisitor {
  public:
-  explicit AotReturnSidVisitor(Map<Expr, Array<IntegerArray>> 
storage_device_map)
-      : storage_device_map_{storage_device_map}, return_sid_{-1} {}
+  // run the visitor on a function.
+  void Run(const Function& func) {
+    node_device_map_ = CollectDeviceInfo(func);
 
-  IntegerArray FindReturnSid(Function func) {
-    VisitExpr(func->body);
-    return return_sid_;
+    for (Expr param : func->params) {
+      CreateStorage(param.operator->());
+    }
+
+    GetStorage(func->body);
   }
 
- protected:
-  void AssignReturnSid(Expr e) {
-    auto iter = storage_device_map_.find(e);
-    if (iter != storage_device_map_.end()) {
-      return_sid_ = (*iter).second[0];
+  std::vector<int> GetReturnIds() const { return return_ids_; }
+
+  StorageMap GetStorageMap() const { return storage_device_map_; }
+
+  void VisitExpr_(const ConstantNode* op) final {
+    CreateStorage(op);
+    AssignReturnSid(GetRef<Expr>(op));
+  }
+
+  void VisitExpr_(const CallNode* op) final {
+    // create token for the call node.
+    CreateStorage(op);
+    for (Expr arg : op->args) {
+      GetStorage(arg);
     }
+    AssignReturnSid(GetRef<Expr>(op));
   }
 
-  void VisitExpr_(const ConstantNode* cn) override {
-    ExprVisitor::VisitExpr_(cn);
-    AssignReturnSid(GetRef<Expr>(cn));
+  void VisitExpr_(const VarNode* op) final {
+    ExprVisitor::VisitExpr_(op);
+    AssignReturnSid(GetRef<Expr>(op));
   }
 
-  void VisitExpr_(const VarNode* vn) override {
-    ExprVisitor::VisitExpr_(vn);
-    AssignReturnSid(GetRef<Expr>(vn));
+  void VisitExpr_(const FunctionNode* op) final {
+    // do not recurse into sub function.
   }
 
-  void VisitExpr_(const CallNode* cn) override {
-    ExprVisitor::VisitExpr_(cn);
-    AssignReturnSid(GetRef<Expr>(cn));
+  void VisitExpr_(const GlobalVarNode* op) final {
+    // Do nothing.
   }
 
-  void VisitExpr_(const LetNode* op) override { VisitExpr(op->body); }
+  void VisitExpr_(const OpNode* op) final {
+    // Do nothing.
+  }
+
+  void VisitExpr_(const TupleNode* op) final {
+    std::vector<StorageInfo> field_sids;
+    Expr expr = GetRef<Expr>(op);
+    for (Expr field : op->fields) {
+      auto sid = GetStorage(field);
+      field_sids.insert(field_sids.end(), sid.begin(), sid.end());
+    }
+
+    storage_device_map_[expr] = field_sids;
+    AssignReturnSid(expr);
+  }
 
-  void VisitExpr_(const TupleNode* tn) override {
-    ExprVisitor::VisitExpr_(tn);
-    AssignReturnSid(GetRef<Expr>(tn));
+  void VisitExpr_(const TupleGetItemNode* op) final {
+    Expr expr = GetRef<Expr>(op);
+    const auto& sids = GetStorage(op->tuple);
+    ICHECK_LT(static_cast<size_t>(op->index), sids.size());
+    storage_device_map_[expr] = {sids[op->index]};
+    AssignReturnSid(expr);
   }
 
+  void VisitExpr_(const IfNode* op) final { LOG(FATAL) << "if is not 
supported."; }
+
+  void VisitExpr_(const LetNode* op) final { LOG(FATAL) << "if is not 
supported."; }
+
  private:
-  Map<Expr, Array<IntegerArray>> storage_device_map_;
-  IntegerArray return_sid_;
+  void AssignReturnSid(Expr e) {
+    if (storage_device_map_.find(e) != storage_device_map_.end()) {
+      auto buffers = storage_device_map_[e];
+      std::vector<int> return_ids;

Review comment:
       nit: perhaps slightly faster to mutate `return_ids_` in place rather than build a local vector and assign it?

##########
File path: src/tir/transforms/storage_rewrite.cc
##########
@@ -206,6 +234,8 @@ class LinearAccessPatternFinder final : public 
StmtExprVisitor {
   bool in_thread_env_{false};
   // The scope stack.
   std::vector<StmtEntry> scope_;
+  // This is a map to connect TVMValues to real allocations
+  std::unordered_map<const VarNode*, std::vector<const VarNode*>> 
value_to_alloc_;

Review comment:
       could you update the comment to better explain the keys and values, and 
the rules for when something should be added here?

##########
File path: tests/python/relay/aot/test_crt_aot.py
##########
@@ -364,5 +364,27 @@ def test_byoc_utvm(use_calculated_workspaces):
     compile_and_run(mod, input_list, output_list, use_calculated_workspaces)
 
 
+def test_quant_mobilenet_tfl():

Review comment:
       yes, it would be great to explain the thing we are trying to test using 
mobilenet, and, if possible, to contrive a test case that does not depend on data 
downloaded from the internet.

##########
File path: src/relay/backend/aot_executor_codegen.cc
##########
@@ -44,52 +45,185 @@ namespace tvm {
 namespace relay {
 namespace backend {
 
+/**
+ * Struct to contain information about the intermediate tensors in the
+ * runner function
+ */
+struct StorageInfo {
+  /*! \brief storage integer identifier of the particular intermediate buffer 
*/
+  int sid;
+  /*! \brief exact size of the temporary */
+  int size_bytes;
+  /*! \brief device type of the intermediate tensor */
+  int dev_type;
+};
+
 using IntegerArray = Array<Integer>;
 using TargetsMap = std::unordered_map<int, Target>;
+using StorageMap = std::unordered_map<Expr, std::vector<StorageInfo>, 
runtime::ObjectPtrHash,
+                                      runtime::ObjectPtrEqual>;
 
-class AotReturnSidVisitor : public ExprVisitor {
+/**
+ * This is an on demand allocator for AOT. A new temporary
+ * (storage allocator identifier) is allocated for each operation.
+ */
+class AOTOnDemandAllocator : public ExprVisitor {
  public:
-  explicit AotReturnSidVisitor(Map<Expr, Array<IntegerArray>> 
storage_device_map)
-      : storage_device_map_{storage_device_map}, return_sid_{-1} {}
+  // run the visitor on a function.
+  void Run(const Function& func) {
+    node_device_map_ = CollectDeviceInfo(func);
 
-  IntegerArray FindReturnSid(Function func) {
-    VisitExpr(func->body);
-    return return_sid_;
+    for (Expr param : func->params) {
+      CreateStorage(param.operator->());
+    }
+
+    GetStorage(func->body);
   }
 
- protected:
-  void AssignReturnSid(Expr e) {
-    auto iter = storage_device_map_.find(e);
-    if (iter != storage_device_map_.end()) {
-      return_sid_ = (*iter).second[0];
+  std::vector<int> GetReturnIds() const { return return_ids_; }
+
+  StorageMap GetStorageMap() const { return storage_device_map_; }
+
+  void VisitExpr_(const ConstantNode* op) final {
+    CreateStorage(op);
+    AssignReturnSid(GetRef<Expr>(op));
+  }
+
+  void VisitExpr_(const CallNode* op) final {
+    // create token for the call node.
+    CreateStorage(op);
+    for (Expr arg : op->args) {
+      GetStorage(arg);
     }
+    AssignReturnSid(GetRef<Expr>(op));
+  }
+
+  void VisitExpr_(const VarNode* op) final {
+    ExprVisitor::VisitExpr_(op);
+    AssignReturnSid(GetRef<Expr>(op));
   }
 
-  void VisitExpr_(const ConstantNode* cn) override {
-    ExprVisitor::VisitExpr_(cn);
-    AssignReturnSid(GetRef<Expr>(cn));
+  void VisitExpr_(const FunctionNode* op) final {
+    // do not recurse into sub function.
   }
 
-  void VisitExpr_(const VarNode* vn) override {
-    ExprVisitor::VisitExpr_(vn);
-    AssignReturnSid(GetRef<Expr>(vn));
+  void VisitExpr_(const GlobalVarNode* op) final {
+    // Do nothing.
   }
 
-  void VisitExpr_(const CallNode* cn) override {
-    ExprVisitor::VisitExpr_(cn);
-    AssignReturnSid(GetRef<Expr>(cn));
+  void VisitExpr_(const OpNode* op) final {
+    // Do nothing.
   }
 
-  void VisitExpr_(const LetNode* op) override { VisitExpr(op->body); }
+  void VisitExpr_(const TupleNode* op) final {
+    std::vector<StorageInfo> field_sids;
+    Expr expr = GetRef<Expr>(op);
+    for (Expr field : op->fields) {
+      auto sid = GetStorage(field);
+      field_sids.insert(field_sids.end(), sid.begin(), sid.end());
+    }
 
-  void VisitExpr_(const TupleNode* tn) override {
-    ExprVisitor::VisitExpr_(tn);
-    AssignReturnSid(GetRef<Expr>(tn));
+    storage_device_map_[expr] = field_sids;
+    AssignReturnSid(expr);
   }
 
+  void VisitExpr_(const TupleGetItemNode* op) final {
+    Expr expr = GetRef<Expr>(op);
+    const auto& sids = GetStorage(op->tuple);
+    ICHECK_LT(static_cast<size_t>(op->index), sids.size());
+    storage_device_map_[expr] = {sids[op->index]};
+    AssignReturnSid(expr);
+  }
+
+  void VisitExpr_(const IfNode* op) final { LOG(FATAL) << "if is not 
supported."; }
+
+  void VisitExpr_(const LetNode* op) final { LOG(FATAL) << "if is not 
supported."; }

Review comment:
       nit: "let is not supported"




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to