Mousius commented on a change in pull request #8697:
URL: https://github.com/apache/tvm/pull/8697#discussion_r687643006



##########
File path: src/relay/backend/aot_executor_codegen.cc
##########
@@ -365,155 +363,21 @@ class AOTExecutorCodegen : public ExprVisitor {
     return ss.str();
   }
 
-  /*!
-   * \brief Update the "main" control function's metadata
-   *
-   * \param func The main function that contains calls to operator tir 
primitive functions
-   */
-  void UpdateMainWorkspaceSize(const tir::PrimFunc& primfunc, const 
relay::Function& func) {
-    auto workspace_byte_alignment = 
target_host_->GetAttr<Integer>("workspace-byte-alignment")
-                                        
.value_or(tvm::runtime::kDefaultWorkspaceAlignment);
-    Integer workspace_size = CalculateWorkspaceBytes(primfunc, 
workspace_byte_alignment);
-    // Populate FunctionInfo
-    auto fi_node = make_object<FunctionInfoNode>();
-    // Initialize all target workspaces to zero
-    for (const auto& kv : targets_) {
-      auto tgt = kv.second;
-      fi_node->workspace_sizes.Set(tgt, 0);
-    }
-    fi_node->workspace_sizes.Set(target_host_, workspace_size);
-    fi_node->relay_primfuncs.Set(target_host_, func);
-
-    int64_t io_size = 0;
-    for (const auto& input : input_vars_) {
-      io_size += CalculateRelayExprSizeBytes(input->checked_type());
-    }
-    io_size += CalculateRelayExprSizeBytes(func->body->checked_type());
-    fi_node->io_sizes.Set(target_host_, io_size);
-
-    int64_t const_size = 0;
-    for (const auto& kv : params_by_expr_) {
-      const_size += CalculateRelayExprSizeBytes(kv.first->checked_type());
-    }
-    fi_node->constant_sizes.Set(target_host_, const_size);
-    function_metadata_.Set(String(runtime::symbol::tvm_module_main), 
FunctionInfo(fi_node));
-  }
-
-  /*!
-   * \brief Update the function metadata for a given cached function and its 
relay
-   * primitive function.
-   *
-   * \param cfunc The cached function as provided the by the compile engine
-   * \param relay_func The source relay primitive function
-   * \param relay_target The target associated with relay primitive function
-   */
-  void UpdateFunctionMetadata(const CachedFunc& cfunc, const Function& 
relay_func,
-                              const Target& relay_target) {
-    auto fi_node = make_object<FunctionInfoNode>();
-    for (const auto& kv : cfunc->funcs->functions) {
-      auto primfunc = Downcast<tir::PrimFunc>(kv.second);
-      auto workspace_byte_alignment =
-          
target_host_->GetAttr<Integer>("workspace-byte-alignment").value_or(16);
-      Integer workspace_size = CalculateWorkspaceBytes(primfunc, 
workspace_byte_alignment);
-      Target primfunc_target = relay_target;
-      if (primfunc->attrs->dict.count("target")) {
-        primfunc_target = Downcast<Target>(primfunc->attrs->dict["target"]);
-      }
-      fi_node->workspace_sizes.Set(primfunc_target, workspace_size);
-      // Calculating size for I/O
-      for (auto const& param : primfunc->params) {
-        auto p_shape = primfunc->buffer_map[param]->shape;
-        int num_of_elements = 1;
-        for (const auto& dim_index_expr : p_shape) {
-          if (dim_index_expr->IsInstance<IntImmNode>()) {
-            num_of_elements *= dim_index_expr.as<IntImmNode>()->value;
-          } else {
-            // If shape is dynamic, we cannot calculate workspace in compile 
time.
-            num_of_elements = 0;
-          }
-        }
-        int element_size = primfunc->buffer_map[param]->dtype.bytes();
-        fi_node->io_sizes.Set(primfunc_target, element_size * num_of_elements);
-      }
-      fi_node->constant_sizes.Set(primfunc_target, 0);
-      fi_node->tir_primfuncs.Set(primfunc_target, primfunc);
-      fi_node->relay_primfuncs.Set(primfunc_target, relay_func);
-    }
-    function_metadata_.Set(cfunc->prim_fn_var->name_hint, 
FunctionInfo(fi_node));
-  }
-
   void VisitExpr_(const CallNode* op) override {
     // Descend the call tree
     for (auto arg : op->args) {
       VisitExpr(arg);
     }
 
-    Expr expr = GetRef<Expr>(op);
-    Function func;
     if (op->op.as<OpNode>()) {
       LOG(FATAL) << "Operators should be transformed away; try applying"
                  << "the fuse_ops transformation to the expression.";
     } else if (op->op.as<GlobalVarNode>()) {
-      LOG(FATAL) << "Not implemented";
-    } else if (op->op.as<FunctionNode>()) {
-      func = GetRef<Function>(op->op.as<FunctionNode>());
+      GlobalVar node = GetRef<GlobalVar>(op->op.as<GlobalVarNode>());
+      CreateFuncCall(GetRef<Call>(op), node->name_hint);
     } else {
       LOG(FATAL) << "TVM runtime does not support calls to " << 
op->op->GetTypeKey();
     }
-    if (!func->HasNonzeroAttr(attr::kPrimitive)) {

Review comment:
      Everything is already lowered at this point, because it has been through
`LowerTE` before this runs, so we don't have to make the assumption - it's
guaranteed :smile_cat:

   I'd also suggest that we avoid adding defensive code for cases we can't find
a way to trigger.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to