This is an automated email from the ASF dual-hosted git repository.

mbrookhart pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0812c07  Change Call with TIRCallAttrs to call_lowered op (#9312)
0812c07 is described below

commit 0812c078bfc8595d079fbb15cdc4a64bf1a4def2
Author: Lily Orth-Smith <[email protected]>
AuthorDate: Wed Nov 10 14:20:01 2021 -0800

    Change Call with TIRCallAttrs to call_lowered op (#9312)
    
    * Introduce call_lowered op
    
    Add op vm.call_tir
    
    Change from checking if CallNode has CallTIRAttrs to checking if the Op is vm.call_tir
    
    Change device_domains to use vm.call_tir op more explicitly
    
    Fix issue in type checker (a segfault still remains at this point)
    
    Fix typo -- most of VM tests pass now
    
    Interpreter now deals with call_tir properly
    
    Fix typo in te_compiler
    
    Use InvokeTVMOp and CallTIR
    
    Add some checks to graph_plan_memory.cc
    
    Make GetToken skip function types
    
    C++ tests pass
    
    Remove prints
    
    formatting
    
    vm.call_tir -> call_tir and more comment removals
    
    call_tir -> call_lowered
    
    fix lint
    
    clang format
    
    Remove compute from non computational vm ops
    
    missed some semicolons in prev commit
    
    Fix warning
    
    Move call_lowered to relay/op/call/call.cc and rename util func
    
    Add helper fn that returns lowered_call op
    
    fix import order
    
    clang format
    
    Add constraint to call_lowered type rel
    
    clean up empty token vector
    
    comment
    
    Move CallTIRAttrs to include/tvm/relay/attrs/call.h
    
    Rename TIRCallAttrs as CallLoweredAttrs
    
    lint
    
    Add helper for extracting func and args from call_lowered
    
    Change graph_executor_codegen to use helper function
    
    Update interpreter to use helper
    
    Fix device_domains.cc -- could still use cleanup; it is also unclear why there are still direct calls to primfns in DomainForCallee
    
    Clean up DeviceCopyProps and lint
    
    lint
    
    return CallLoweredAttrs with the extern func
    
    comment
    
    note in comment
    
    Progress & notes. Realized that I am not handling externs correctly
    
    Fix extern handling that previously only worked by accident
    
    Clean up CreateFuncCall signature, notes
    
    comments
    
    Fix extern function handling
    
    extern_function -> extern_func
    
    fix DeviceAwareVisitExpr_ -- now it handles both lowered and normal calls
    
    AOT tests pass
    
    formatting and comment removal
    
    cleanup
    
    Introduce call_lowered op
    
    * lint
    
    * Fix AOT to deal with externs
    
    * add const auto&
    
    * Fix aot crt test
---
 include/tvm/relay/attrs/annotation.h               |  11 --
 .../op/vm/vm.h => include/tvm/relay/attrs/call.h   |  30 ++--
 src/relay/backend/aot_executor_codegen.cc          |  77 +++++++---
 .../contrib/example_target_hooks/relay_to_tir.cc   |  12 +-
 src/relay/backend/graph_executor_codegen.cc        |  89 ++++++-----
 src/relay/backend/graph_plan_memory.cc             |  52 ++++---
 src/relay/backend/interpreter.cc                   | 152 +++++++++---------
 src/relay/backend/te_compiler.cc                   | 169 +++++++++++----------
 src/relay/op/call/call.cc                          | 116 ++++++++++++++
 src/relay/op/call/call.h                           |  74 +++++++++
 src/relay/op/memory/device_copy.cc                 |  17 +++
 src/relay/op/vm/vm.h                               |   2 +-
 src/relay/transforms/device_domains.cc             |  33 ++--
 src/relay/transforms/memory_alloc.cc               |  16 +-
 14 files changed, 574 insertions(+), 276 deletions(-)

diff --git a/include/tvm/relay/attrs/annotation.h 
b/include/tvm/relay/attrs/annotation.h
index 85ac3f36..f88ca8e 100644
--- a/include/tvm/relay/attrs/annotation.h
+++ b/include/tvm/relay/attrs/annotation.h
@@ -116,17 +116,6 @@ struct CompilerAttrs : public 
tvm::AttrsNode<CompilerAttrs> {
   }
 };
 
-/*!
- * \brief Metadata for calls to TIR functions, useful for program analysis 
crossing Relay and TIR.
- */
-struct TIRCallAttrs : public tvm::AttrsNode<TIRCallAttrs> {
-  /*! \brief The metadata attached to the call node. */
-  Map<String, ObjectRef> metadata;
-
-  TVM_DECLARE_ATTRS(TIRCallAttrs, "relay.attrs.TIRCallAttrs") {
-    TVM_ATTR_FIELD(metadata).describe("Metadata attached to the TIR function 
call.");
-  }
-};
 
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/op/vm/vm.h b/include/tvm/relay/attrs/call.h
similarity index 57%
copy from src/relay/op/vm/vm.h
copy to include/tvm/relay/attrs/call.h
index 802c810..2b02c6a 100644
--- a/src/relay/op/vm/vm.h
+++ b/include/tvm/relay/attrs/call.h
@@ -18,23 +18,31 @@
  */
 
 /*!
- * \file src/relay/op/vm/vm.h
- * \brief Dialect operators for Relay VM.
+ * \file tvm/relay/attrs/call.h
+ * \brief Attribute for call_lowered operator.
  */
-#ifndef TVM_RELAY_OP_VM_VM_H_
-#define TVM_RELAY_OP_VM_VM_H_
+#ifndef TVM_RELAY_ATTRS_CALL_H_
+#define TVM_RELAY_ATTRS_CALL_H_
 
-#include "tvm/relay/expr.h"
+#include <tvm/ir/attrs.h>
+
+#include <string>
 
 namespace tvm {
 namespace relay {
 
-Expr InvokeTVMOp(Expr func, Expr inputs, Expr outputs);
-Expr ShapeFunc(Expr func, Expr inputs, Expr outputs, Array<tvm::Integer> 
is_input);
-Expr ShapeOf(Expr expr);
-Expr ReshapeTensor(Expr data, Expr shape, Array<PrimExpr> newshape);
+/*!
+ * \brief Metadata for calls to TIR functions, useful for program analysis 
crossing Relay and TIR.
+ */
+struct CallLoweredAttrs : public tvm::AttrsNode<CallLoweredAttrs> {
+  /*! \brief The metadata attached to the call node. */
+  Map<String, ObjectRef> metadata;
+
+  TVM_DECLARE_ATTRS(CallLoweredAttrs, "relay.attrs.CallLoweredAttrs") {
+    TVM_ATTR_FIELD(metadata).describe("Metadata attached to the lowered 
function call.");
+  }
+};
 
 }  // namespace relay
 }  // namespace tvm
-
-#endif  // TVM_RELAY_OP_VM_VM_H_
+#endif  // TVM_RELAY_ATTRS_CALL_H_
diff --git a/src/relay/backend/aot_executor_codegen.cc 
b/src/relay/backend/aot_executor_codegen.cc
index 7e57022..58bcccf 100644
--- a/src/relay/backend/aot_executor_codegen.cc
+++ b/src/relay/backend/aot_executor_codegen.cc
@@ -24,6 +24,7 @@
 
 #include <tvm/ir/module.h>
 #include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/attrs/call.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/runtime/device_api.h>
 #include <tvm/runtime/object.h>
@@ -40,6 +41,7 @@
 #include <vector>
 
 #include "../op/annotation/annotation.h"
+#include "../op/call/call.h"
 #include "../transforms/device_aware_visitors.h"
 #include "./te_compiler.h"
 #include "./utils.h"
@@ -72,14 +74,34 @@ class AOTOnDemandAllocator : public 
transform::DeviceAwareExprVisitor {
     AssignReturnSid(GetRef<Expr>(op));
   }
 
-  void DeviceAwareVisitExpr_(const CallNode* op) final {
-    // create token for the call node.
-    VisitExpr(op->op);
-    CreateStorage(op);
-    for (Expr arg : op->args) {
+  void DeviceAwareVisitExpr_(const CallNode* call_node) final {
+    // AOTOnDemandAllocator is run both before and after lowering, so we need 
to handle the case
+    // where the op of the call is a generic function
+
+    Expr func;
+    Array<Expr> args;
+
+    if (call_node->op == CallLoweredOp()) {
+      CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);
+      func = call_lowered_props.lowered_func;
+      args = call_lowered_props.arguments;
+    } else {  // Relay functions that have not been lowered and lowered extern 
functions
+      func = call_node->op;
+      args = call_node->args;
+      if (call_node->op.as<GlobalVarNode>()) {  // Lowered extern function
+        ICHECK(!(call_node->attrs.defined())) << "Extern functions should have 
null attributes.";
+      } else {  // Relay function which has not been lowered yet
+        ICHECK(call_node->op.as<FunctionNode>())
+            << "Expected the call to be to a lowered primfunc, a lowered 
extern function or a "
+               "unlowered Relay function.";
+      }
+    }
+    VisitExpr(func);
+    CreateStorage(call_node);
+    for (const Expr& arg : args) {
       GetStorage(arg);
     }
-    AssignReturnSid(GetRef<Expr>(op));
+    AssignReturnSid(GetRef<Expr>(call_node));
   }
 
   void VisitExpr_(const VarNode* op) final { 
AssignReturnSid(GetRef<Expr>(op)); }
@@ -287,13 +309,18 @@ class AOTExecutorCodegen : public MixedModeVisitor {
   }
 
   /*!
-   * brief Call a function with a given name
+   * brief Create a function call
+   * \param call_lowered_props The lowered function and the arguments to call 
it with
+   * \param call The call we got func and args from
    */
-  void CreateFuncCall(Call call, std::string func_name) {
+  void CreateFuncCall(CallLoweredProps call_lowered_props, Call call) {
+    std::string func_name = call_lowered_props.lowered_func->name_hint;
+
     tvm::Array<PrimExpr> args{tvm::tir::StringImm(func_name)};
     std::vector<tir::Stmt> create_func_call_stmts;
+
     // Pack the inputs
-    for (Expr arg : call->args) {
+    for (const Expr& arg : call_lowered_props.arguments) {
       if (params_by_expr_.find(arg) != params_by_expr_.end()) {
         auto param_handle = tvm::tir::Call(DataType::Handle(), 
tvm::tir::builtin::lookup_param(),
                                            
{tir::StringImm(params_by_expr_[arg])});
@@ -371,21 +398,25 @@ class AOTExecutorCodegen : public MixedModeVisitor {
     return ss.str();
   }
 
-  void VisitExpr_(const CallNode* op) override {
+  void VisitExpr_(const CallNode* call_node) override {
     // Descend the call tree
-    for (auto arg : op->args) {
-      VisitExpr(arg);
-    }
-
-    if (op->op.as<OpNode>()) {
-      LOG(FATAL) << "Operators should be transformed away; try applying"
-                 << "the fuse_ops transformation to the expression.";
-    } else if (op->op.as<GlobalVarNode>()) {
-      GlobalVar node = GetRef<GlobalVar>(op->op.as<GlobalVarNode>());
-      CreateFuncCall(GetRef<Call>(op), node->name_hint);
+    CallLoweredProps call_lowered_props;
+    if (const auto* gvn = call_node->op.as<GlobalVarNode>()) {  // Lowered 
extern function
+      ICHECK(!(call_node->attrs.defined())) << "Extern functions should have 
null attributes.";
+      for (const auto& arg : call_node->args) {
+        VisitExpr(arg);
+      }
+      call_lowered_props = CallLoweredProps{GetRef<GlobalVar>(gvn), 
call_node->args, {}};
     } else {
-      LOG(FATAL) << "TVM runtime does not support calls to " << 
op->op->GetTypeKey();
+      ICHECK(call_node->op == CallLoweredOp()) << "Operators should be 
transformed away; Try "
+                                                  "applying the fuse_ops 
transformation to the "
+                                                  "expression.";
+      call_lowered_props = GetCallLoweredProps(call_node);
+      for (const auto& arg : call_lowered_props.arguments) {
+        VisitExpr(arg);
+      }
     }
+    CreateFuncCall(call_lowered_props, GetRef<Call>(call_node));
   }
 
   void VisitExpr_(const VarNode* op) override {
@@ -443,7 +474,9 @@ class AOTExecutorCodegen : public MixedModeVisitor {
   }
   void VisitExpr_(const TupleGetItemNode* op) override { VisitExpr(op->tuple); 
}
   void VisitExpr_(const OpNode* op) override {
-    LOG(FATAL) << "All OpNodes should have been expanded";
+    if (GetRef<Op>(op) != CallLoweredOp()) {
+      LOG(FATAL) << "All OpNodes except for call_lowered should have been 
expanded";
+    }
   }
   void VisitExpr_(const IfNode* op) override {
     LOG(FATAL) << "All GlobalVarNodes should be removed before AOT executor's 
Codegen is called";
diff --git a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc 
b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc
index cae2021..c41399e 100644
--- a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc
+++ b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc
@@ -17,14 +17,18 @@
  * specific language governing permissions and limitations
  * under the License.
  */
+#include <tvm/relay/attrs/call.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/transform.h>
+#include <tvm/runtime/memory.h>
 #include <tvm/tir/buffer.h>
 #include <tvm/tir/builtin.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/function.h>
 #include <tvm/tir/op.h>
 
+#include "../../../op/call/call.h"
+
 namespace tvm {
 namespace relay {
 namespace contrib {
@@ -109,7 +113,13 @@ class ConvertAddToSubtract : public MixedModeMutator {
         GlobalVar new_global_var(func_name.value());
         new_global_var->checked_type_ = func->checked_type();
         ReplaceAddWithSubtractPrimFunc(new_global_var, GetRef<Function>(func));
-        return Call(new_global_var, call->args, call->attrs, call->type_args, 
call->span);
+
+        // Since we are replacing the Relay function with a call to a TIR 
function, we must use the
+        // call_lowered op.
+        auto call_lowered_attrs = make_object<CallLoweredAttrs>();
+        call_lowered_attrs->metadata.Set("relay_attrs", call->attrs);
+        return CallLowered(std::move(new_global_var), call->args,
+                           std::move(Attrs(call_lowered_attrs)), 
call->type_args, call->span);
       }
     }
 
diff --git a/src/relay/backend/graph_executor_codegen.cc 
b/src/relay/backend/graph_executor_codegen.cc
index d32ded3..ac3c835 100644
--- a/src/relay/backend/graph_executor_codegen.cc
+++ b/src/relay/backend/graph_executor_codegen.cc
@@ -26,6 +26,7 @@
 #include <dmlc/json.h>
 #include <tvm/ir/module.h>
 #include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/attrs/call.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/runtime/device_api.h>
 #include <tvm/runtime/object.h>
@@ -37,6 +38,7 @@
 #include <vector>
 
 #include "../op/annotation/annotation.h"
+#include "../op/call/call.h"
 #include "../transforms/device_aware_visitors.h"
 #include "./te_compiler.h"
 #include "./utils.h"
@@ -403,64 +405,75 @@ class GraphExecutorCodegen : public 
backend::MemoizedExprTranslator<std::vector<
     return lhs_storage_id == rhs_storage_id;
   }
 
-  std::vector<GraphNodeRef> GraphAddCallNode(const CallNode* op, const 
std::string& func_name,
-                                             GraphAttrs attrs) {
+  std::vector<GraphNodeRef> GraphAddCallNode(const CallNode* call_node, 
GraphAttrs attrs) {
+    Call call = GetRef<Call>(call_node);
     std::vector<GraphNodeRef> inputs;
-    for (auto arg : op->args) {
-      auto res = VisitExpr(arg);
-      for (auto nr : res) {
-        inputs.push_back(nr);
-      }
-    }
+    std::string func_name;
 
-    /// An adapted version of the storage optimization for the time being.
-    bool reshape_only = false;
-    if (op->attrs.defined()) {
-      if (auto tir_call_attrs = op->attrs.as<TIRCallAttrs>()) {
-        Map<String, ObjectRef> metadata = tir_call_attrs->metadata;
-        if (metadata.count(attr::kReshapeOnly) &&
-            Downcast<tvm::Integer>(metadata[attr::kReshapeOnly])->value == 1) {
-          reshape_only = true;
-        }
+    if (call->op == CallLoweredOp()) {
+      // Extract function and arguments from the call_lowered op
+      CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);
 
-        auto relay_attrs = 
Downcast<DictAttrs>(tir_call_attrs->metadata["relay_attrs"]);
+      func_name = call_lowered_props.lowered_func->name_hint;
 
-        for (auto p : relay_attrs->dict) {
-          if (p.second.as<StringObj>()) {
-            attrs[p.first] = std::string(Downcast<String>(p.second));
+      for (const Expr& arg : call_lowered_props.arguments) {
+        for (auto n : VisitExpr(arg)) {
+          inputs.push_back(n);
+        }
+      }
+      if (call_lowered_props.attrs.metadata.count("relay_attrs")) {
+        if (auto relay_attrs =
+                
call_lowered_props.attrs.metadata["relay_attrs"].as<DictAttrsNode>()) {
+          for (auto p : relay_attrs->dict) {
+            if (p.second.as<StringObj>()) {
+              attrs[p.first] = std::string(Downcast<String>(p.second));
+            }
           }
         }
       }
-    }
-
-    if (reshape_only && ShareSameStorage(GetRef<Expr>(op), op->args[0])) {
-      auto node = GraphOpNode::make_node_ptr("reshape_nop", GraphAttrs(), 
"__nop", inputs, attrs);
-      return AddNode(node, GetRef<Expr>(op));
+      bool reshape_only = false;
+      if (call_lowered_props.attrs.metadata.count(attr::kReshapeOnly) &&
+          
Downcast<tvm::Integer>(call_lowered_props.attrs.metadata[attr::kReshapeOnly])->value
 ==
+              1) {
+        reshape_only = true;
+      }
+      if (reshape_only &&
+          ShareSameStorage(GetRef<Expr>(call_node), 
call_lowered_props.arguments[0])) {
+        auto node = GraphOpNode::make_node_ptr("reshape_nop", GraphAttrs(), 
"__nop", inputs, attrs);
+        return AddNode(node, call);
+      }
+    } else if (!call_node->attrs.defined()) {  // Call is an extern function
+      std::cout << "call_node: \n" << PrettyPrint(call) << std::endl;
+      const auto* func = call_node->op.as<GlobalVarNode>();
+      ICHECK(func) << "Expected the operator to be a global var, but got "
+                   << call_node->op->GetTypeKey();  // getting a relay fn 
here, not sure why.
+      func_name = func->name_hint;
+
+      for (const Expr& arg : call_node->args) {
+        for (auto n : VisitExpr(arg)) {
+          inputs.push_back(n);
+        }
+      }
+    } else {
+      LOG(FATAL) << "Non-primitive-call nodes should have been transformed 
away.\n"
+                 << "The graph executor code generator expects all calls to be 
call_lowered, "
+                 << "but found: " << std::endl
+                 << PrettyPrint(call);
     }
 
     // Compute the operator name, because we used the get unique name when 
generating the kernel.
     auto op_name = _GetUniqueName(func_name);
     auto node = GraphOpNode::make_node_ptr(op_name, GraphAttrs(), func_name, 
inputs, attrs);
-    return AddNode(node, GetRef<Expr>(op));
+    return AddNode(node, call);
   }
 
   std::vector<GraphNodeRef> VisitExpr_(const CallNode* call_node) override {
-    relay::Call call = GetRef<Call>(call_node);
     auto props = GetOnDeviceProps(call_node);
     if (props.body.defined()) {
       // See through "on_device" calls.
       return VisitExpr(props.body);
     }
-
-    const auto* global_node = call->op.as<GlobalVarNode>();
-    ICHECK(global_node)
-        << "Non-primitive-call nodes should have been transformed away.\n"
-        << "The graph executor code generator expects all calls to have their 
callee "
-           "normalized to a GlobalVar, but found:"
-        << std::endl
-        << PrettyPrint(call);
-    auto prim_fn_name = global_node->name_hint;
-    return GraphAddCallNode(call_node, prim_fn_name, GraphAttrs());
+    return GraphAddCallNode(call_node, GraphAttrs());
   }
 
   std::vector<GraphNodeRef> VisitExpr_(const LetNode* op) override {
diff --git a/src/relay/backend/graph_plan_memory.cc 
b/src/relay/backend/graph_plan_memory.cc
index 961252a..4031dfd 100644
--- a/src/relay/backend/graph_plan_memory.cc
+++ b/src/relay/backend/graph_plan_memory.cc
@@ -24,6 +24,7 @@
  */
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/attrs/call.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/transform.h>
@@ -32,6 +33,7 @@
 
 #include "../../support/arena.h"
 #include "../op/annotation/annotation.h"
+#include "../op/call/call.h"
 #include "../op/memory/memory.h"
 #include "../transforms/device_aware_visitors.h"
 #include "./utils.h"
@@ -139,6 +141,8 @@ class StorageAllocaBaseVisitor : public 
transform::DeviceAwareExprVisitor {
  protected:
   /*! \brief internal token map */
   std::unordered_map<const ExprNode*, std::vector<StorageToken*>> token_map_;
+  /*! \brief empty token map */
+  const std::vector<StorageToken*> no_tokens_;
 
   /*!
    * \brief Get the necessary token.
@@ -146,6 +150,11 @@ class StorageAllocaBaseVisitor : public 
transform::DeviceAwareExprVisitor {
    * \return The corresponding token.
    */
   const std::vector<StorageToken*>& GetToken(const Expr& expr) {
+    this->VisitExpr(expr);
+    // Functions don't require data storage, represented by the empty token
+    if (expr->checked_type().as<FuncTypeNode>()) {
+      return no_tokens_;
+    }
     // See through on_device calls.
     Expr real_expr = IgnoreOnDevice(expr);
     this->VisitExpr(real_expr);
@@ -159,8 +168,9 @@ class StorageAllocaBaseVisitor : public 
transform::DeviceAwareExprVisitor {
    * \brief Allocates (or reuses if \p can_realloc is true) a storage token 
for holding
    * the result of evaluating \p op.
    */
-  void CreateToken(const ExprNode* op, bool can_realloc) {
-    return CreateTokenOnDevice(op, GetInScopeDeviceType(GetRef<Expr>(op)), 
can_realloc);
+  void CreateToken(const ExprNode* expr_node, bool can_realloc) {
+    return CreateTokenOnDevice(expr_node, 
GetInScopeDeviceType(GetRef<Expr>(expr_node)),
+                               can_realloc);
   }
 
   /*!
@@ -203,12 +213,12 @@ class StorageAllocaInit : protected 
StorageAllocaBaseVisitor {
 
   using StorageAllocaBaseVisitor::DeviceAwareVisitExpr_;
 
-  void DeviceAwareVisitExpr_(const CallNode* op) final {
+  void DeviceAwareVisitExpr_(const CallNode* call_node) final {
     // create token for the call node.
-    CreateToken(op, true);
+    CreateToken(call_node, true);
 
     // for each input, visit argument token.
-    for (Expr arg : op->args) {
+    for (Expr arg : call_node->args) {
       for (StorageToken* tok : GetToken(arg)) {
         tok->ref_counter += 1;
       }
@@ -273,7 +283,6 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
                  << "expressions are assigned with virtual device types. 
Either all "
                     "or none of the expressions are expected to be annotated.";
     }
-
     return backend::StaticMemoryPlan(smap);
   }
 
@@ -320,10 +329,13 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
   using StorageAllocaBaseVisitor::DeviceAwareVisitExpr_;
 
   // The call map
-  void DeviceAwareVisitExpr_(const CallNode* op) final {
+  void DeviceAwareVisitExpr_(const CallNode* call_node) final {
     std::vector<StorageToken*> args;
     // for each input, visit argument token.
-    for (Expr arg : op->args) {
+
+    for (const Expr& arg : call_node->args) {
+      // Note: GetToken skips GlobalVars and handles tuples properly, so we 
don't need to treat
+      // call_lowered specially.
       for (StorageToken* tok : GetToken(arg)) {
         args.push_back(tok);
       }
@@ -337,20 +349,17 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
     //
     // TODO(tvm-team) Update checks of flat memory enablement when we support
     // opaque-nd memory planning to skip this path.
-    if (IsReshape(op)) {
-      // TODO(@electriclilies, jroesch): This check is failing because the 
size of args is 3
-      // I can't figure out where the extra args are coming from, I assume it 
must be related
-      // to the relay_attrs field we added to the TIRCallArgs, but I don't 
know where / how
-      // that's happening...
+
+    if (IsReshape(call_node)) {
       ICHECK_EQ(args.size(), 1U);
-      ReuseInputToken(op, args[0]);
+      ReuseInputToken(call_node, args[0]);
     } else {
       // create token for the call node.
-      CreateToken(op, true);
+      CreateToken(call_node, true);
     }
 
     // check if there is orphaned output that can be released immediately.
-    for (StorageToken* tok : token_map_.at(op)) {
+    for (StorageToken* tok : token_map_.at(call_node)) {
       CheckForRelease(tok);
     }
     for (StorageToken* tok : args) {
@@ -376,12 +385,11 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
       return fn->HasNonzeroAttr(attr::kReshapeOnly);
     }
 
-    if (call->attrs.defined()) {
-      if (auto tir_call_attrs = call->attrs.as<TIRCallAttrs>()) {
-        Map<String, ObjectRef> metadata = tir_call_attrs->metadata;
-        return metadata.count(attr::kReshapeOnly) &&
-               (Downcast<tvm::Integer>(metadata[attr::kReshapeOnly])->value == 
1);
-      }
+    if (call->op == CallLoweredOp()) {
+      CallLoweredProps call_lowered_props = GetCallLoweredProps(call);
+      Map<String, ObjectRef> metadata = call_lowered_props.attrs.metadata;
+      return metadata.count(attr::kReshapeOnly) &&
+             (Downcast<tvm::Integer>(metadata[attr::kReshapeOnly])->value == 
1);
     }
 
     return false;
diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc
index 13b8556..4835d76 100644
--- a/src/relay/backend/interpreter.cc
+++ b/src/relay/backend/interpreter.cc
@@ -25,6 +25,7 @@
 #include <tvm/driver/driver_api.h>
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/attrs/call.h>
 #include <tvm/relay/attrs/debug.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/feature.h>
@@ -37,6 +38,7 @@
 #include <tvm/target/compilation_config.h>
 
 #include "../op/annotation/annotation.h"
+#include "../op/call/call.h"
 #include "../transforms/pass_utils.h"
 #include "te_compiler.h"
 
@@ -682,82 +684,94 @@ class Interpreter : public ExprFunctor<ObjectRef(const 
Expr& n)>,
   }
 
   ObjectRef VisitExpr_(const CallNode* call_node) final {
-    std::vector<ObjectRef> args;
-    for (auto arg : call_node->args) {
-      args.push_back(Eval(arg));
-    }
+    if (call_node->op == CallLoweredOp()) {  // Special case: Call a lowered 
TIR function.
+      CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);
 
-    if (call_node->op == OnDeviceOp()) {
-      // Special case: The call 'on_device(expr)' denotes that expr should be 
executed on
-      // a particular device. We can ignore this during interpretation.
-      ICHECK_EQ(call_node->args.size(), 1UL);
-      return args[0];
-    }
+      // Evaluate only function args
+      std::vector<ObjectRef> args;
+      for (auto arg : call_lowered_props.arguments) {
+        args.push_back(Eval(arg));
+      }
 
-    // We should not find calls to operators after running fusion and lowering.
-    if (const OpNode* op_node = call_node->op.as<OpNode>()) {
-      LOG(FATAL) << "found " << op_node->name
-                 << "; operators should have been removed by previous passes; 
try "
-                    "fusing and lowering";
-    }
+      // TODO(mbs): Make calling convention first-class in Relay.
+      Array<GlobalVar> all_prim_fn_vars;
+      if (call_lowered_props.attrs.metadata.count("all_prim_fn_vars")) {
+        all_prim_fn_vars =
+            
Downcast<Array<GlobalVar>>(call_lowered_props.attrs.metadata.at("all_prim_fn_vars"));
+      }
+      GlobalVar prim_shape_fn_var;
+      if (call_lowered_props.attrs.metadata.count("prim_shape_fn_var")) {
+        prim_shape_fn_var =
+            
Downcast<GlobalVar>(call_lowered_props.attrs.metadata.at("prim_shape_fn_var"));
+      }
+      Array<GlobalVar> all_prim_shape_fn_vars;
+      if (call_lowered_props.attrs.metadata.count("all_prim_shape_fn_vars")) {
+        all_prim_shape_fn_vars = Downcast<Array<GlobalVar>>(
+            call_lowered_props.attrs.metadata.at("all_prim_shape_fn_vars"));
+      }
+      Array<Integer> prim_shape_fn_states;
+      if (call_lowered_props.attrs.metadata.count("prim_shape_fn_states")) {
+        prim_shape_fn_states =
+            
Downcast<Array<Integer>>(call_lowered_props.attrs.metadata.at("prim_shape_fn_states"));
+      }
 
-    if (const ConstructorNode* con = call_node->op.as<ConstructorNode>()) {
-      // Special case: ADT constructor
-      return ConstructorValue(con->tag, args, GetRef<Constructor>(con));
-    }
+      size_t num_shape_inputs = 0;
+      if (call_lowered_props.attrs.metadata.count("prim_shape_fn_num_inputs")) 
{
+        num_shape_inputs = static_cast<size_t>(
+            
Downcast<Integer>(call_lowered_props.attrs.metadata.at("prim_shape_fn_num_inputs"))
+                ->value);
+      }
+      size_t num_shape_outputs = 0;
+      if 
(call_lowered_props.attrs.metadata.count("prim_shape_fn_num_outputs")) {
+        num_shape_outputs = static_cast<size_t>(
+            
Downcast<Integer>(call_lowered_props.attrs.metadata.at("prim_shape_fn_num_outputs"))
+                ->value);
+      }
+      ICHECK(config_->optional_homogeneous_target.defined());
+      return InvokePrimitiveOp(call_lowered_props.lowered_func, 
all_prim_fn_vars,
+                               config_->optional_homogeneous_target, 
prim_shape_fn_var,
+                               all_prim_shape_fn_vars, prim_shape_fn_states, 
num_shape_inputs,
+                               num_shape_outputs, 
config_->host_se_scope->target, args);
+    } else {  // All other calls
+      // Evaluate all arguments
+      std::vector<ObjectRef> args;
+      for (auto arg : call_node->args) {
+        args.push_back(Eval(arg));
+      }
 
-    if (const GlobalVarNode* gvn = call_node->op.as<GlobalVarNode>()) {
-      if (const TIRCallAttrs* attrs = call_node->attrs.as<TIRCallAttrs>()) {
-        // Special case: Call a lowered TIR function.
-        // TODO(mbs): Make calling convention first-class in Relay.
-        Array<GlobalVar> all_prim_fn_vars;
-        if (attrs->metadata.count("all_prim_fn_vars")) {
-          all_prim_fn_vars = 
Downcast<Array<GlobalVar>>(attrs->metadata.at("all_prim_fn_vars"));
-        }
-        GlobalVar prim_shape_fn_var;
-        if (attrs->metadata.count("prim_shape_fn_var")) {
-          prim_shape_fn_var = 
Downcast<GlobalVar>(attrs->metadata.at("prim_shape_fn_var"));
-        }
-        Array<GlobalVar> all_prim_shape_fn_vars;
-        if (attrs->metadata.count("all_prim_shape_fn_vars")) {
-          all_prim_shape_fn_vars =
-              
Downcast<Array<GlobalVar>>(attrs->metadata.at("all_prim_shape_fn_vars"));
-        }
-        Array<Integer> prim_shape_fn_states;
-        if (attrs->metadata.count("prim_shape_fn_states")) {
-          prim_shape_fn_states =
-              
Downcast<Array<Integer>>(attrs->metadata.at("prim_shape_fn_states"));
-        }
-        size_t num_shape_inputs = 0;
-        if (attrs->metadata.count("prim_shape_fn_num_inputs")) {
-          num_shape_inputs = static_cast<size_t>(
-              
Downcast<Integer>(attrs->metadata.at("prim_shape_fn_num_inputs"))->value);
-        }
-        size_t num_shape_outputs = 0;
-        if (attrs->metadata.count("prim_shape_fn_num_outputs")) {
-          num_shape_outputs = static_cast<size_t>(
-              
Downcast<Integer>(attrs->metadata.at("prim_shape_fn_num_outputs"))->value);
-        }
+      if (call_node->op == OnDeviceOp()) {
+        // Special case: The call 'on_device(expr)' denotes that expr should 
be executed on
+        // a particular device. We can ignore this during interpretation.
+        ICHECK_EQ(call_node->args.size(), 1UL);
+        return args[0];
+      }
+      if (const ConstructorNode* con = call_node->op.as<ConstructorNode>()) {
+        // Special case: ADT constructor
 
-        ICHECK(config_->optional_homogeneous_target.defined());
-        return InvokePrimitiveOp(GetRef<GlobalVar>(gvn), all_prim_fn_vars,
-                                 config_->optional_homogeneous_target, 
prim_shape_fn_var,
-                                 all_prim_shape_fn_vars, prim_shape_fn_states, 
num_shape_inputs,
-                                 num_shape_outputs, 
config_->host_se_scope->target, args);
+        return ConstructorValue(con->tag, args, GetRef<Constructor>(con));
       }
-    }
 
-    // Now we just evaluate and expect to find a closure.
-    ObjectRef fn_val = Eval(call_node->op);
-    if (const InterpreterClosureObj* closure_node = 
fn_val.as<InterpreterClosureObj>()) {
-      auto closure = GetRef<InterpreterClosure>(closure_node);
-      return Invoke(closure, args);
-    } else if (const RecClosureObj* closure_node = fn_val.as<RecClosureObj>()) 
{
-      return Invoke(closure_node->clos, args, closure_node->bind);
-    } else {
-      LOG(FATAL) << "internal error: type error, expected function value in 
the call "
-                 << "position";
-      return ObjectRef();
+      if (const OpNode* op_node = call_node->op.as<OpNode>()) {
+        // Except for call_lowered and on_device, we should not find calls to 
operators after
+        // running fusion and lowering.
+        LOG(FATAL) << "found " << op_node->name
+                   << "; operators should have been removed by previous 
passes; try "
+                      "fusing and lowering";
+      }
+
+      // Now we just evaluate and expect to find a closure.
+      // TODO(@electriclilies): How should call_lowered behave with closures?
+      ObjectRef fn_val = Eval(call_node->op);
+      if (const InterpreterClosureObj* closure_node = 
fn_val.as<InterpreterClosureObj>()) {
+        auto closure = GetRef<InterpreterClosure>(closure_node);
+        return Invoke(closure, args);
+      } else if (const RecClosureObj* closure_node = 
fn_val.as<RecClosureObj>()) {
+        return Invoke(closure_node->clos, args, closure_node->bind);
+      } else {
+        LOG(FATAL) << "internal error: type error, expected function value in 
the call "
+                   << "position";
+        return ObjectRef();
+      }
     }
   }
 
diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc
index 163bb9f..915fc22 100644
--- a/src/relay/backend/te_compiler.cc
+++ b/src/relay/backend/te_compiler.cc
@@ -24,6 +24,7 @@
 #include <tvm/ir/function.h>
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/attrs/call.h>
 #include <tvm/relay/attrs/device_copy.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/expr_functor.h>
@@ -43,6 +44,7 @@
 #include <vector>
 
 #include "../op/annotation/annotation.h"
+#include "../op/call/call.h"
 #include "../transforms/device_aware_visitors.h"
 #include "./te_compiler_cache.h"
 #include "./utils.h"
@@ -460,7 +462,8 @@ class LowerTensorExprMutator : public 
DeviceAwareExprMutator {
    * to the TIR implementation, and attributes to attach to the call to 
identify it as
    * a TIR call.
    */
-  std::pair<GlobalVar, Attrs> LowerFunction(Function func, Target target) {
+  Expr MakeLoweredCall(Function func, Array<Expr> visited_args, Array<Type> 
type_args, Span span,
+                       Target target) {
     if (func->GetAttr<String>(attr::kCompiler).defined()) {
       // BYOC flow.
       CCacheKey key = CCacheKey(func, target);
@@ -468,6 +471,7 @@ class LowerTensorExprMutator : public 
DeviceAwareExprMutator {
       ICHECK(ext_func.defined()) << "Lowering returned undefined function for "
                                  << ext_func->prim_fn_var->name_hint;
 
+      // TODO(@areusch, @jroesch): this metadata is for AOT, this should be 
our interface for AOT
       Map<GlobalVar, tir::PrimFunc> prim_fns;
       relay::Function func_with_metadata = func;
       func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", 
ext_func->prim_fn_var);
@@ -478,87 +482,91 @@ class LowerTensorExprMutator : public 
DeviceAwareExprMutator {
       // act when we process a function.
       this->process_fn_(func_with_metadata);
 
-      // TODO(mbs): Need TIRCallAttrs or equiv so targets know this is an 
extern.
       // TODO(mbs): Dynamic shapes?
-      return {ext_func->prim_fn_var, Attrs()};
-    }
+      // TODO(@mbs, electriclilies): Make extern functions explicit
+      return Call(ext_func->prim_fn_var, visited_args, Attrs(), type_args, 
span);
 
-    // Non-External Relay Function
-    VLOG(1) << "lowering to target '" << target->str() << "' for primitive:\n" 
<< PrettyPrint(func);
-    CCacheKey key = CCacheKey(func, target);
-    CachedFunc lowered_func = compiler_->Lower(key, module_name_);
-    VLOG(1) << "lowered primitive bound to '" << 
PrettyPrint(lowered_func->prim_fn_var) << "'";
-
-    // Collect all the lowered functions produced for this primitive function.
-    Map<GlobalVar, tir::PrimFunc> prim_fns;
-    Array<GlobalVar> all_prim_fn_vars;
-    for (auto prim_fn : lowered_func->funcs->functions) {
-      CHECK(prim_fn.second.as<tir::PrimFuncNode>()) << "must be a prim fn";
-      prim_fns.Set(prim_fn.first, Downcast<tir::PrimFunc>(prim_fn.second));
-      all_prim_fn_vars.push_back(prim_fn.first);
-      VLOG(1) << "lowered primitive includes bindings for '" << 
PrettyPrint(prim_fn.first) << "'";
-    }
+    } else {
+      // Non-External Relay Function
+      VLOG(1) << "lowering to target '" << target->str() << "' for 
primitive:\n"
+              << PrettyPrint(func);
+      CCacheKey key = CCacheKey(func, target);
+      CachedFunc lowered_func = compiler_->Lower(key, module_name_);
+      VLOG(1) << "lowered primitive bound to '" << 
PrettyPrint(lowered_func->prim_fn_var) << "'";
 
-    // TODO(@areusch, @jroesch): this metadata is for AOT, this should be our 
interface for AOT
-    relay::Function func_with_metadata = func;
-    func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", 
lowered_func->prim_fn_var);
-    func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns);
-    func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, 
lowered_func->target);
+      // Collect all the lowered functions produced for this primitive 
function.
+      Map<GlobalVar, tir::PrimFunc> prim_fns;
+      Array<GlobalVar> all_prim_fn_vars;
+      for (auto prim_fn : lowered_func->funcs->functions) {
+        CHECK(prim_fn.second.as<tir::PrimFuncNode>()) << "must be a prim fn";
+        prim_fns.Set(prim_fn.first, Downcast<tir::PrimFunc>(prim_fn.second));
+        all_prim_fn_vars.push_back(prim_fn.first);
+        VLOG(1) << "lowered primitive includes bindings for '" << 
PrettyPrint(prim_fn.first) << "'";
+      }
 
-    // Provide a callback hook which allows one-level up code generators to
-    // act when we process a function.
-    this->process_fn_(func_with_metadata);
+      // TODO(@areusch, @jroesch): this metadata is for AOT, this should be 
our interface for AOT
+      relay::Function func_with_metadata = func;
+      func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", 
lowered_func->prim_fn_var);
+      func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", 
prim_fns);
+      func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, 
lowered_func->target);
 
-    auto tir_call_attrs = make_object<TIRCallAttrs>();
-    if (func->HasNonzeroAttr(attr::kReshapeOnly)) {
-      tir_call_attrs->metadata.Set(attr::kReshapeOnly, tvm::Integer(1));
-    }
+      // Provide a callback hook which allows one-level up code generators to
+      // act when we process a function.
+      this->process_fn_(func_with_metadata);
 
-    auto device_copy = IsDeviceCopy(func);
-    if (std::get<0>(device_copy)) {
-      // Record that device copy source and destination devices so the device 
planner can
-      // still follow along.
-      auto source_device = std::get<1>(device_copy);
-      auto dst_device = std::get<2>(device_copy);
-      tir_call_attrs->metadata.Set("source_device", 
tvm::Integer(source_device));
-      tir_call_attrs->metadata.Set("dst_device", tvm::Integer(dst_device));
-    }
+      auto call_lowered_attrs = make_object<CallLoweredAttrs>();
+      if (func->HasNonzeroAttr(attr::kReshapeOnly)) {
+        call_lowered_attrs->metadata.Set(attr::kReshapeOnly, tvm::Integer(1));
+      }
 
-    tir_call_attrs->metadata.Set("relay_attrs", func->attrs);
-    tir_call_attrs->metadata.Set("all_prim_fn_vars", all_prim_fn_vars);
-
-    if (IsDynamic(func->ret_type)) {
-      // Also lower the dynamic shape function.
-      // Shape function keys use the underlying primitive function as their 
'function',
-      // but the generic 'cpu' target as the target since all shape functions 
run
-      // on the host cpu irrespective of where the primitive runs.
-      // TODO(mbs): Cleanup target handling.
-      Target shape_target("llvm");
-      VLOG(1) << "lowering to target '" << shape_target->str()
-              << "' for dynamic shape function for primitive";
-      CCacheKey shape_key(func, shape_target);
-      CachedFunc lowered_shape_func = compiler_->LowerShapeFunc(shape_key);
-      // Capture the shape function's global var and parameters 'states' in 
call
-      // annotations so calling convention can be recovered.
-      // TODO(mbs): Capture all this as part of a 'call into TIR' construct 
once available.
-      // The way the shape function calling convention is derived and passed 
to call sites
-      // via the 'parameter states' could be improved.
-      tir_call_attrs->metadata.Set("prim_shape_fn_var", 
lowered_shape_func->prim_fn_var);
-      tir_call_attrs->metadata.Set("prim_shape_fn_states",
-                                   
lowered_shape_func->shape_func_param_states);
-      tir_call_attrs->metadata.Set("prim_shape_fn_num_inputs",
-                                   
Integer(static_cast<int>(lowered_shape_func->inputs.size())));
-      tir_call_attrs->metadata.Set("prim_shape_fn_num_outputs",
-                                   
Integer(static_cast<int>(lowered_shape_func->outputs.size())));
-      Array<GlobalVar> all_prim_shape_fn_vars;
-      for (auto prim_shape_fn : lowered_shape_func->funcs->functions) {
-        CHECK(prim_shape_fn.second.as<tir::PrimFuncNode>()) << "must be a prim 
fn";
-        all_prim_shape_fn_vars.push_back(prim_shape_fn.first);
+      auto device_copy = IsDeviceCopy(func);
+      if (std::get<0>(device_copy)) {
+        // Record that device copy source and destination devices so the 
device planner can
+        // still follow along.
+        auto source_device = std::get<1>(device_copy);
+        auto dst_device = std::get<2>(device_copy);
+        call_lowered_attrs->metadata.Set("source_device", 
tvm::Integer(source_device));
+        call_lowered_attrs->metadata.Set("dst_device", 
tvm::Integer(dst_device));
       }
-      tir_call_attrs->metadata.Set("all_prim_shape_fn_vars", 
all_prim_shape_fn_vars);
-    }
 
-    return {lowered_func->prim_fn_var, Attrs(tir_call_attrs)};
+      call_lowered_attrs->metadata.Set("relay_attrs", func->attrs);
+      call_lowered_attrs->metadata.Set("all_prim_fn_vars", all_prim_fn_vars);
+
+      if (IsDynamic(func->ret_type)) {
+        // Also lower the dynamic shape function.
+        // Shape function keys use the underlying primitive function as their 
'function',
+        // but the generic 'cpu' target as the target since all shape 
functions run
+        // on the host cpu irrespective of where the primitive runs.
+        // TODO(mbs): Cleanup target handling.
+        Target shape_target("llvm");
+        VLOG(1) << "lowering to target '" << shape_target->str()
+                << "' for dynamic shape function for primitive";
+        CCacheKey shape_key(func, shape_target);
+        CachedFunc lowered_shape_func = compiler_->LowerShapeFunc(shape_key);
+        // Capture the shape function's global var and parameters 'states' in 
call
+        // annotations so calling convention can be recovered.
+        // TODO(mbs): Capture all this as part of a 'call into TIR' construct 
once available.
+        // The way the shape function calling convention is derived and passed 
to call sites
+        // via the 'parameter states' could be improved.
+        call_lowered_attrs->metadata.Set("prim_shape_fn_var", 
lowered_shape_func->prim_fn_var);
+        call_lowered_attrs->metadata.Set("prim_shape_fn_states",
+                                         
lowered_shape_func->shape_func_param_states);
+        call_lowered_attrs->metadata.Set(
+            "prim_shape_fn_num_inputs",
+            Integer(static_cast<int>(lowered_shape_func->inputs.size())));
+        call_lowered_attrs->metadata.Set(
+            "prim_shape_fn_num_outputs",
+            Integer(static_cast<int>(lowered_shape_func->outputs.size())));
+        Array<GlobalVar> all_prim_shape_fn_vars;
+        for (auto prim_shape_fn : lowered_shape_func->funcs->functions) {
+          CHECK(prim_shape_fn.second.as<tir::PrimFuncNode>()) << "must be a 
prim fn";
+          all_prim_shape_fn_vars.push_back(prim_shape_fn.first);
+        }
+        call_lowered_attrs->metadata.Set("all_prim_shape_fn_vars", 
all_prim_shape_fn_vars);
+      }
+      return CallLowered(lowered_func->prim_fn_var, visited_args, 
Attrs(call_lowered_attrs),
+                         type_args, span);
+    }
   }
 
   std::pair<Var, Expr> PreVisitLetBinding_(const Var& var, const Expr& value) 
final {
@@ -593,6 +601,9 @@ class LowerTensorExprMutator : public 
DeviceAwareExprMutator {
   }
 
   Expr DeviceAwareVisitExpr_(const CallNode* call_node) override {
+    // Passes before lowering might insert a call_lowered to call a function 
that has already
+    // been lowered. Therefore we might see call_lowered ops here, but we 
don't need to do anything
+    // because ResolveToPrimitive returns null for all calls where the 
call_node->op is an OpNode
     Call call = GetRef<Call>(call_node);
 
     // Look for (indirect) calls to primitives.
@@ -628,15 +639,13 @@ class LowerTensorExprMutator : public 
DeviceAwareExprMutator {
       // TODO(mbs): Replace device_type with target so this lookup is 
unnecessary.
       target = GetTargetFromInteger(device_type, targets_);
     }
-
+    Array<Expr> visited_args;
+    for (const auto& arg : call_node->args) {
+      visited_args.push_back(VisitExpr(arg));
+    }
     // Lower the primitive function for that target.
     Function func = Downcast<Function>(prim_func);
-    std::pair<GlobalVar, Attrs> pair = LowerFunction(func, target);
-
-    // Replace with direct call to lowered primitive, and attach annotations 
to record calling
-    // convention.
-    // =====> in new call_lowered form
-    return Call(pair.first, args, pair.second);
+    return MakeLoweredCall(func, visited_args, call_node->type_args, 
call_node->span, target);
   }
 
   IRModule module_;
diff --git a/src/relay/op/call/call.cc b/src/relay/op/call/call.cc
new file mode 100644
index 0000000..9485b72
--- /dev/null
+++ b/src/relay/op/call/call.cc
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/op/call/call.cc
+ * \brief Operators for calling lowered functions.
+ */
+
+#include "./call.h"
+
+#include <tvm/relay/attrs/call.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+
+#include "../../transforms/infer_layout_utils.h"
+
+namespace tvm {
+namespace relay {
+
+TVM_REGISTER_NODE_TYPE(CallLoweredAttrs);
+
+/*!
+ * \brief Type relation for call_lowered.
+ *
+ * Types = [func, call_args, ret_type]. Constrains the tuple of call arguments to equal the
+ * function's parameter types (modulo the Tuple wrapper) and the result to equal the function's
+ * return type. Returns false while types[0] is not yet a FuncType or types[1] not yet a TupleType.
+ */
+bool CallLoweredRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                    const TypeReporter& reporter) {
+  // Types = [func, call_args, ret_type]
+  if (types.size() != 3u) {
+    return false;
+  }
+  const auto* func_type = types[0].as<FuncTypeNode>();
+  if (!func_type) {
+    return false;
+  }
+
+  const auto* tuple_type_node = types[1].as<TupleTypeNode>();
+  if (!tuple_type_node) {
+    return false;
+  }
+
+  // Constraint to ensure function arguments are the same type as the inputs to the function (modulo
+  // the Tuple wrapper)
+  reporter->Assign(GetRef<TupleType>(tuple_type_node), TupleType(func_type->arg_types, {}));
+  // Constraint to ensure the output of call_lowered is the same as the function's return type
+  reporter->Assign(types[2], func_type->ret_type);
+  return true;
+}
+
+const Op& CallLoweredOp() {
+  // Cache the registry lookup once: several passes (device planning, memory allocation, device
+  // copy detection) compare call ops against this, so avoid a name lookup on every use.
+  static const Op& op = Op::Get("call_lowered");
+  return op;
+}
+
+Expr CallLowered(Expr func, Array<Expr> inputs, Attrs attrs, Array<Type> type_args, Span span) {
+  // Right now, call_lowered only supports func being a global var pointing to the lowered
+  // function.
+  // The failure messages must not dereference an undefined func/attrs, otherwise a bad call
+  // segfaults instead of reporting the check failure.
+  ICHECK(func.as<GlobalVarNode>())
+      << "Function to call should be GlobalVarNode, but got "
+      << (func.defined() ? func->GetTypeKey() : std::string("(undefined)"));
+  ICHECK(attrs.as<CallLoweredAttrs>())
+      << "Expected attributes to be CallLoweredAttrs, but got "
+      << (attrs.defined() ? attrs->GetTypeKey() : std::string("(undefined)"));
+  return Call(CallLoweredOp(), {std::move(func), Tuple(std::move(inputs))}, std::move(attrs),
+              std::move(type_args), std::move(span));
+}
+
+TVM_REGISTER_GLOBAL("relay.op.call_lowered")
+    .set_body_typed([](Expr func, Array<Expr> inputs, Attrs attrs, Array<Type> type_args,
+                       Span span) {
+      // `inputs` is already the flat argument list and CallLowered wraps it in a Tuple itself.
+      // (inputs.as<TupleNode>() on an Array is always nullptr, so it must not be used here.)
+      return CallLowered(func, inputs, attrs, type_args, span);
+    });
+
+RELAY_REGISTER_OP("call_lowered")
+    .describe(R"code(Invoke an operation compiled by TVM.)code" TVM_ADD_FILELINE)
+    .set_num_inputs(2)
+    .set_attrs_type<CallLoweredAttrs>()
+    .add_argument("func", "Function", "The lowered function to call.")
+    .add_argument("call_args", "Tuple", "The input tensors.")
+    .add_type_rel("CallLoweredRel", CallLoweredRel)
+    .set_support_level(10)
+    .set_attr<TOpPattern>("TOpPattern", kOpaque)
+    .set_attr<TOpIsStateful>("TOpIsStateful", false)
+    .set_attr<TNonComputational>("TNonComputational", true)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
+
+CallLoweredProps GetCallLoweredProps(const CallNode* call_node) {
+  ICHECK(call_node->op == CallLoweredOp())
+      << "GetCallLoweredProps expects the op to be call_lowered. ";
+  ICHECK(call_node->args.size() == 2) << "Expected call_lowered to have 2 arguments. ";
+  const auto* function = call_node->args[0].as<GlobalVarNode>();
+  ICHECK(function) << "Expected first arg to call_lowered to be a GlobalVar. ";
+
+  const auto* tuple_args = call_node->args[1].as<TupleNode>();
+  ICHECK(tuple_args) << "Expected second arg to call_lowered to be a Tuple. ";
+
+  ICHECK(call_node->attrs.defined()) << "Attributes for call_lowered should be defined!";
+  const auto* attrs = call_node->attrs.as<CallLoweredAttrs>();
+  ICHECK(attrs) << "Expected call_lowered op to have CallLoweredAttrs, but found "
+                << call_node->attrs->GetTypeKey();
+  // The source nodes are reached through const pointers, so the fields and attrs are copied
+  // (std::move on them would silently copy anyway).
+  return CallLoweredProps{GetRef<GlobalVar>(function), tuple_args->fields, *attrs};
+}
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/op/call/call.h b/src/relay/op/call/call.h
new file mode 100644
index 0000000..381be67
--- /dev/null
+++ b/src/relay/op/call/call.h
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/op/call/call.h
+ * \brief Operators for calling lowered functions.
+ */
+#ifndef TVM_RELAY_OP_CALL_CALL_H_
+#define TVM_RELAY_OP_CALL_CALL_H_
+
+#include <tvm/relay/attrs/call.h>
+#include <tvm/relay/expr.h>
+
+#include <utility>
+
+namespace tvm {
+namespace relay {
+
+/*!
+ * \brief Helper to construct a Relay call with the call_lowered op.
+ *
+ * Builds Call(call_lowered, {func, Tuple(inputs)}, attrs, type_args, span).
+ * \param func Lowered function to call with call_lowered; must be a GlobalVar (ICHECKed).
+ * \param inputs Arguments to be passed to the function; they are wrapped in a Tuple.
+ * \param attrs Call attributes; must be CallLoweredAttrs (ICHECKed).
+ * \param type_args Type arguments for the call.
+ * \param span TVM span for propagating debugging info.
+ * \return The call_lowered Call expression.
+ */
+Expr CallLowered(Expr func, Array<Expr> inputs, Attrs attrs, Array<Type> type_args, Span span);
+
+/*!
+ * \brief Returns the Relay "call_lowered" op. Compare a CallNode's op against this helper
+ * rather than looking the op up by name at each use site.
+ */
+const Op& CallLoweredOp();
+
+/*!
+ * \brief The pieces of a Call("call_lowered", ...): the lowered function, the arguments to
+ * call it with, and the call's attributes.
+ */
+struct CallLoweredProps {
+  /*! \brief Global variable pointing to the lowered function. */
+  GlobalVar lowered_func;
+  /*! \brief Array of the arguments to call lowered_func with (the fields of the args Tuple). */
+  Array<Expr> arguments;
+  /*! \brief Attributes of the call_lowered call (a copy of its CallLoweredAttrs). */
+  CallLoweredAttrs attrs;
+};
+
+/*!
+ * \brief Helper to extract the lowered function, its arguments and its attributes from
+ * Call("call_lowered", ...).
+ *
+ * Fails (ICHECK) if the op is not "call_lowered", if the call does not have exactly two
+ * arguments (a GlobalVar followed by a Tuple), or if its attributes are not CallLoweredAttrs.
+ * \param call_node CallNode that we want to get the function and its arguments from.
+ * \return The extracted function, arguments and attributes.
+ */
+CallLoweredProps GetCallLoweredProps(const CallNode* call_node);
+
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_OP_CALL_CALL_H_
diff --git a/src/relay/op/memory/device_copy.cc 
b/src/relay/op/memory/device_copy.cc
index dce89aa..9106b95 100644
--- a/src/relay/op/memory/device_copy.cc
+++ b/src/relay/op/memory/device_copy.cc
@@ -24,6 +24,7 @@
 
 #include "./device_copy.h"
 
+#include <tvm/relay/attrs/call.h>
 #include <tvm/relay/attrs/device_copy.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/op.h>
@@ -31,6 +32,8 @@
 #include <tvm/topi/elemwise.h>
 
 #include "../../transforms/infer_layout_utils.h"
+#include "../annotation/annotation.h"
+#include "../call/call.h"
 #include "../type_relations.h"
 
 namespace tvm {
@@ -86,6 +89,7 @@ on different devices.
                              return {topi::identity(inputs[0])};
                            });
 
+// Get device copy props for original device copy op
 DeviceCopyProps GetDeviceCopyProps(const CallNode* call_node) {
   if (call_node->op == DeviceCopyOp()) {
     ICHECK_EQ(call_node->args.size(), 1) << "device_copy expects one argument";
@@ -103,6 +107,19 @@ DeviceCopyProps GetDeviceCopyProps(const CallNode* 
call_node) {
     } else {
       return {call_node->args[0], src_dev_type, dst_dev_type};
     }
+  } else if (call_node->op == CallLoweredOp()) {
+    /* Get device props for a TIR function */
+    CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);
+
+    if (call_lowered_props.attrs.metadata.count("source_device") == 1 &&
+        call_lowered_props.attrs.metadata.count("dst_device") == 1) {
+      ICHECK_EQ(call_lowered_props.arguments.size(), 1) << "device_copy is of 
arity 1";
+      return {call_lowered_props.lowered_func,
+              static_cast<DLDeviceType>(
+                  
Downcast<Integer>(call_lowered_props.attrs.metadata["source_device"])->value),
+              static_cast<DLDeviceType>(
+                  
Downcast<Integer>(call_lowered_props.attrs.metadata["dst_device"])->value)};
+    }
   }
   return {};
 }
diff --git a/src/relay/op/vm/vm.h b/src/relay/op/vm/vm.h
index 802c810..68d25b0 100644
--- a/src/relay/op/vm/vm.h
+++ b/src/relay/op/vm/vm.h
@@ -24,7 +24,7 @@
 #ifndef TVM_RELAY_OP_VM_VM_H_
 #define TVM_RELAY_OP_VM_VM_H_
 
-#include "tvm/relay/expr.h"
+#include <tvm/relay/expr.h>
 
 namespace tvm {
 namespace relay {
diff --git a/src/relay/transforms/device_domains.cc 
b/src/relay/transforms/device_domains.cc
index 1578485..b9fa049 100644
--- a/src/relay/transforms/device_domains.cc
+++ b/src/relay/transforms/device_domains.cc
@@ -24,9 +24,11 @@
 
 #include "./device_domains.h"
 
+#include <tvm/relay/attrs/call.h>
 #include <tvm/relay/attrs/memory.h>
 
 #include "../op/annotation/annotation.h"
+#include "../op/call/call.h"
 #include "../op/memory/device_copy.h"
 
 namespace tvm {
@@ -47,20 +49,19 @@ constexpr size_t mix(size_t h1, size_t h2) {
  * See te_compiler.cc for where this rewriting occurs.
  */
 DeviceCopyProps GetPrimitiveDeviceCopyProps(const CallNode* call_node) {
-  auto tir_call_attrs = call_node->attrs.as<TIRCallAttrs>();
-  if (tir_call_attrs == nullptr) {
-    return {};
-  }
-  if (tir_call_attrs->metadata.count("source_device") != 1 ||
-      tir_call_attrs->metadata.count("dst_device") != 1) {
-    return {};
+  if (call_node->op == CallLoweredOp()) {
+    // A lowered device copy is a call_lowered whose metadata records both "source_device" and
+    // "dst_device" (te_compiler sets these when IsDeviceCopy(func) holds). The value being
+    // copied is the single argument passed to the lowered function.
+    CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);
+    if (call_lowered_props.attrs.metadata.count("source_device") == 1 &&
+        call_lowered_props.attrs.metadata.count("dst_device") == 1) {
+      ICHECK_EQ(call_lowered_props.arguments.size(), 1) << "device_copy is of 
arity 1";
+      return {call_lowered_props.arguments[0],
+              static_cast<DLDeviceType>(
+                  
Downcast<Integer>(call_lowered_props.attrs.metadata["source_device"])->value),
+              static_cast<DLDeviceType>(
+                  
Downcast<Integer>(call_lowered_props.attrs.metadata["dst_device"])->value)};
+    }
   }
-  ICHECK_EQ(call_node->args.size(), 1) << "device_copy is of arity 1";
-  return {
-      call_node->args[0],
-      static_cast<DLDeviceType>(
-          Downcast<Integer>(tir_call_attrs->metadata["source_device"])->value),
-      
static_cast<DLDeviceType>(Downcast<Integer>(tir_call_attrs->metadata["dst_device"])->value)};
+  // Not a lowered device copy: return default-constructed (empty) props.
+  return {};
 }
 
 }  // namespace
@@ -319,8 +320,12 @@ DeviceDomainPtr DeviceDomains::DomainForCallee(const Call& 
call) {
       args_and_result.emplace_back(param_domain);
     }
     args_and_result.emplace_back(result_domain);
+  } else if (call->op == CallLoweredOp()) {
+    CallLoweredProps call_lowered_props = GetCallLoweredProps(call.get());
+    return DomainFor(call_lowered_props.lowered_func);
   } else {
-    // Defer to normal case where op can be an arbitrary expression.
+    // We still need to handle the case where the function / op is not lowered
+    // because the device planner runs before and after lowering.
     return DomainFor(call->op);
   }
   auto domain = MakeDomain(std::move(args_and_result));
diff --git a/src/relay/transforms/memory_alloc.cc 
b/src/relay/transforms/memory_alloc.cc
index 81d704e..a328eaa 100644
--- a/src/relay/transforms/memory_alloc.cc
+++ b/src/relay/transforms/memory_alloc.cc
@@ -26,6 +26,7 @@
 #include <tvm/node/structural_hash.h>
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/attrs/call.h>
 #include <tvm/relay/attrs/device_copy.h>
 #include <tvm/relay/attrs/memory.h>
 #include <tvm/relay/expr.h>
@@ -44,6 +45,7 @@
 #include "../backend/te_compiler.h"
 #include "../backend/te_compiler_cache.h"
 #include "../op/annotation/annotation.h"
+#include "../op/call/call.h"
 #include "../op/memory/device_copy.h"
 #include "../op/memory/memory.h"
 #include "../op/vm/vm.h"
@@ -74,12 +76,11 @@ bool IsReshapeOnly(const Expr& expr) {
     return func->HasNonzeroAttr(attr::kReshapeOnly);
   }
   if (const CallNode* call = expr.as<CallNode>()) {
-    if (call->attrs.defined()) {
-      if (auto tir_call_attrs = call->attrs.as<TIRCallAttrs>()) {
-        Map<String, ObjectRef> metadata = tir_call_attrs->metadata;
-        return metadata.count(attr::kReshapeOnly) &&
-               (Downcast<tvm::Integer>(metadata[attr::kReshapeOnly])->value == 
1);
-      }
+    if (call->op == CallLoweredOp()) {
+      CallLoweredProps call_lowered_props = GetCallLoweredProps(call);
+      Map<String, ObjectRef> metadata = call_lowered_props.attrs.metadata;
+      return metadata.count(attr::kReshapeOnly) &&
+             (Downcast<tvm::Integer>(metadata[attr::kReshapeOnly])->value == 
1);
     }
   }
   return false;
@@ -377,7 +378,8 @@ class DialectRewriter : public 
transform::DeviceAwareExprMutator {
     }
 
     Tuple tuple_outs(outs);
-    auto invoke = OnDevice(InvokeTVMOp(func, ins, tuple_outs), 
dev.device_type, /*is_fixed=*/true);
+    auto call = InvokeTVMOp(func, ins, tuple_outs);
+    auto invoke = OnDevice(call, dev.device_type, /*is_fixed=*/true);
     scope->Push(invoke);
     return ToTupleType(ret_type,
                        std::vector<Expr>(tuple_outs->fields.begin(), 
tuple_outs->fields.end()));

Reply via email to