electriclilies commented on a change in pull request #9483:
URL: https://github.com/apache/tvm/pull/9483#discussion_r759706150



##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -481,206 +557,217 @@ class LowerTensorExprMutator : public 
DeviceAwareExprMutator {
    */
   Expr MakeLoweredCall(Function func, Array<Expr> visited_args, Array<Type> 
type_args, Span span,
                        Target target) {
-    if (func->GetAttr<String>(attr::kCompiler).defined()) {
-      // BYOC flow.
-      CCacheKey key = CCacheKey(func, target);
-      CachedFunc ext_func = compiler_->Lower(key, module_name_);
-      ICHECK(ext_func.defined()) << "Lowering returned undefined function for "
-                                 << ext_func->prim_fn_var->name_hint;
-
-      // TODO(@areusch, @jroesch): this metadata is for AOT, this should be 
our interface for AOT
-      Map<GlobalVar, tir::PrimFunc> prim_fns;
-      relay::Function func_with_metadata = func;
-      func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", 
ext_func->prim_fn_var);
-      func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", 
prim_fns);
-      func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, 
ext_func->target);
-
-      // Provide a callback hook which allows one-level up code generators to
-      // act when we process a function.
-      this->process_fn_(func_with_metadata);
-
-      // TODO(mbs): Dynamic shapes?
-      // TODO(@mbs, electriclilies): Make extern functions explicit
-      return Call(ext_func->prim_fn_var, visited_args, Attrs(), type_args, 
span);
-
-    } else {
-      // Non-External Relay Function
-      CCacheKey key = CCacheKey(func, target);
-      CachedFunc lowered_func = compiler_->Lower(key, module_name_);
-
-      // Collect all the lowered functions produced for this primitive 
function.
-      Map<GlobalVar, tir::PrimFunc> prim_fns;
-      Array<GlobalVar> all_prim_fn_vars;
-      for (auto prim_fn : lowered_func->funcs->functions) {
-        CHECK(prim_fn.second.as<tir::PrimFuncNode>()) << "must be a prim fn";
-        prim_fns.Set(prim_fn.first, Downcast<tir::PrimFunc>(prim_fn.second));
-        all_prim_fn_vars.push_back(prim_fn.first);
-      }
-
-      // TODO(@areusch, @jroesch): this metadata is for AOT, this should be 
our interface for AOT
-      relay::Function func_with_metadata = func;
-      func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", 
lowered_func->prim_fn_var);
-      func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", 
prim_fns);
-      func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, 
lowered_func->target);
-
-      // Provide a callback hook which allows one-level up code generators to
-      // act when we process a function.
-      this->process_fn_(func_with_metadata);
-
-      auto call_lowered_attrs = make_object<CallLoweredAttrs>();
-      if (func->HasNonzeroAttr(attr::kReshapeOnly)) {
-        call_lowered_attrs->metadata.Set(attr::kReshapeOnly, tvm::Integer(1));
-      }
-
-      DeviceCopyProps props = GetDeviceCopyProps(func);
-      if (props.body.defined()) {
-        // Record the device copy source and destination scopes so the device 
planner can
-        // still follow along even after lowering.
-        call_lowered_attrs->metadata.Set("src_se_scope", props.src_se_scope);
-        call_lowered_attrs->metadata.Set("dst_se_scope", props.dst_se_scope);
+    CCacheKey key = CCacheKey(func, target);
+    CachedFunc cfunc = compiler_->Lower(key, module_name_);
+    ICHECK(cfunc.defined());
+
+    auto opt_compiler = func->GetAttr<String>(attr::kCompiler);
+
+    // Add some metadata on top of the *original function* and invoke the 
callback so it can
+    // be captured.
+    // TODO(@areusch, @jroesch): this metadata is for AOT, this should be our 
interface for AOT
+    Map<GlobalVar, tir::PrimFunc> prim_fns;
+    Array<GlobalVar> all_prim_fn_vars;
+    for (const auto& kv : cfunc->funcs->functions) {
+      if (opt_compiler) {
+        // We expect just the original func but with just the ExternalSymbol 
attribute signalling

Review comment:
       typo: signalling -> signaling

##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -951,22 +1039,87 @@ void UpdateFunctionMetadata(BaseFunc func,
   function_metadata.Set(prim_fn_var.value()->name_hint, fi);
 }
 
-IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn 
process_fn) {
-  TECompiler compiler;
-
-  auto updated_module = LowerTensorExpr(module_name, compiler, 
process_fn)(module);
-
-  backend::UpdateAutoSchedulerOpWeights(compiler);
+IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn 
process_fn,
+                 SEScope host_se_scope) {
+  TECompiler compiler(module);
+
+  // TODO(mbs): This is all unnecessarily convoluted. Better would be to 
accumulate the rewritten
+  // module as we go (including rewritten Functions, lowered primitives, and 
runtime modules
+  // generated by external toolchains), and use a pair of maps over vars and 
global vars
+  // to global vars to remember which functions have already been lowered.
+
+  // Lower all the callees in module:
+  //  - Functions tagged with "Compiler" are unchanged (checked by 
CreateFunctionPass)
+  //  - Functions tagged with "Primitive" are unchanged (checked by 
LowerTensorExprMutator)
+  //  - Called functions tagged with "Compiler" are copied into the compiler 
cache with a fresh
+  //    GlobalVar, and calls updated (sticking with regular Relay Call).
+  //  - Calls to functions tagged with "Primitive" are compiled to PrimFuncs, 
and calls updated
+  //    (using call_lowered convention).
+  IRModule updated_module = LowerTensorExpr(module_name, compiler, 
std::move(process_fn),
+                                            std::move(host_se_scope))(module);
+
+  // The Functions tagged with "Compiler" are now residing in the cache ready 
to be
+  // compiled by LowerExternalFunctions. However we still need a record of 
them in the
+  // IRModule so that the various executors can see which function names need 
to be
+  // retrieved. They may, however, have been renamed.
+  compiler->AddExterns(updated_module);
+
+  // Add the lowered functions.
+  IRModule lowered_module = compiler->GetLoweredFunctions();
+  VLOG(1) << "capturing " << lowered_module->functions.size() << " new lowered 
functions";
+  for (const auto& kv : lowered_module->functions) {
+    if (updated_module->ContainGlobalVar(kv.first->name_hint)) {
+      LOG(FATAL) << "duplicate bindings for '" << kv.first->name_hint
+                 << "'. Existing is:" << std::endl
+                 << PrettyPrint(updated_module->Lookup(kv.first->name_hint)) 
<< std::endl
+                 << "while new is:" << std::endl
+                 << PrettyPrint(kv.second);
+    }
+    updated_module->Add(kv.first, kv.second);
+  }
 
-  // Copy the lowered functions into the return module
-  updated_module->Update(compiler->GetLoweredFunctions());
+  // Invoke external codegen for all Functions in the cache tagged with 
"Compiler", and
+  // annotate the module with the resulting runtime modules.
+  // TODO(mbs): runtime modules should be first class rather than attributes.
+  Array<runtime::Module> external_mods =
+      module->GetAttr<Array<runtime::Module>>("external_mods", 
Array<runtime::Module>()).value();
+  Array<runtime::Module> new_external_mods = 
compiler->LowerExternalFunctions();
+  VLOG(1) << "capturing " << external_mods.size() << " existing and " << 
new_external_mods.size()
+          << " new external modules";
+  for (const auto& mod : new_external_mods) {
+    external_mods.push_back(mod);  // copy-on-write.
+  }
 
-  // Annotate the module with C Device API context mapping, the external 
modules and function info
-  // this is until we have Target's annotated for the C Device API
+  // Annotate the module with C Device API context mapping (this is until we 
have Target's
+  // annotated for the C Device API)
   // TODO(Mousius) - Remove "device_contexts" as soon as we have the graph 
annotated properly with
   // Target's

Review comment:
       Again, "Target's" should be "Targets" or "targets".

##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -951,22 +1039,87 @@ void UpdateFunctionMetadata(BaseFunc func,
   function_metadata.Set(prim_fn_var.value()->name_hint, fi);
 }
 
-IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn 
process_fn) {
-  TECompiler compiler;
-
-  auto updated_module = LowerTensorExpr(module_name, compiler, 
process_fn)(module);
-
-  backend::UpdateAutoSchedulerOpWeights(compiler);
+IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn 
process_fn,
+                 SEScope host_se_scope) {
+  TECompiler compiler(module);
+
+  // TODO(mbs): This is all unnecessarily convoluted. Better would be to 
accumulate the rewritten
+  // module as we go (including rewritten Functions, lowered primitives, and 
runtime modules
+  // generated by external toolchains), and use a pair of maps over vars and 
global vars
+  // to global vars to remember which functions have already been lowered.
+
+  // Lower all the callees in module:
+  //  - Functions tagged with "Compiler" are unchanged (checked by 
CreateFunctionPass)
+  //  - Functions tagged with "Primitive" are unchanged (checked by 
LowerTensorExprMutator)
+  //  - Called functions tagged with "Compiler" are copied into the compiler 
cache with a fresh
+  //    GlobalVar, and calls updated (sticking with regular Relay Call).
+  //  - Calls to functions tagged with "Primitive" are compiled to PrimFuncs, 
and calls updated
+  //    (using call_lowered convention).
+  IRModule updated_module = LowerTensorExpr(module_name, compiler, 
std::move(process_fn),
+                                            std::move(host_se_scope))(module);
+
+  // The Functions tagged with "Compiler" are now residing in the cache ready 
to be
+  // compiled by LowerExternalFunctions. However we still need a record of 
them in the
+  // IRModule so that the various executors can see which function names need 
to be
+  // retrieved. They may, however, have been renamed.
+  compiler->AddExterns(updated_module);
+
+  // Add the lowered functions.
+  IRModule lowered_module = compiler->GetLoweredFunctions();
+  VLOG(1) << "capturing " << lowered_module->functions.size() << " new lowered 
functions";
+  for (const auto& kv : lowered_module->functions) {
+    if (updated_module->ContainGlobalVar(kv.first->name_hint)) {
+      LOG(FATAL) << "duplicate bindings for '" << kv.first->name_hint
+                 << "'. Existing is:" << std::endl
+                 << PrettyPrint(updated_module->Lookup(kv.first->name_hint)) 
<< std::endl
+                 << "while new is:" << std::endl
+                 << PrettyPrint(kv.second);
+    }
+    updated_module->Add(kv.first, kv.second);
+  }
 
-  // Copy the lowered functions into the return module
-  updated_module->Update(compiler->GetLoweredFunctions());
+  // Invoke external codegen for all Functions in the cache tagged with 
"Compiler", and
+  // annotate the module with the resulting runtime modules.
+  // TODO(mbs): runtime modules should be first class rather than attributes.
+  Array<runtime::Module> external_mods =
+      module->GetAttr<Array<runtime::Module>>("external_mods", 
Array<runtime::Module>()).value();
+  Array<runtime::Module> new_external_mods = 
compiler->LowerExternalFunctions();
+  VLOG(1) << "capturing " << external_mods.size() << " existing and " << 
new_external_mods.size()
+          << " new external modules";
+  for (const auto& mod : new_external_mods) {
+    external_mods.push_back(mod);  // copy-on-write.
+  }
 
-  // Annotate the module with C Device API context mapping, the external 
modules and function info
-  // this is until we have Target's annotated for the C Device API
+  // Annotate the module with C Device API context mapping (this is until we 
have Target's

Review comment:
       typo: Target's -> Targets

##########
File path: src/relay/backend/te_compiler_cache.cc
##########
@@ -427,26 +414,49 @@ class MakeShapeFunc : public 
backend::MemoizedExprTranslator<Array<te::Tensor>>
       candidate_name = truncated_name.str();
     }
 
-    // Set all the inputs correctly.
+    // Set all the inputs correctly, and accumulate their types from the 
p.o.v. of the
+    // shape function rather than the primitive it is derived for.
     Array<te::Tensor> inputs;
+    Array<Type> shape_function_arg_types;
     for (auto param : prim_func->params) {
       int state = param_states_[param];
       shape_func_param_states.push_back(IntImm(DataType::Int(32), state));
       if (state & kNeedInputData) {
+        // Pass the primitive arguments directly (though in flattened form and 
on the host)
         for (auto t : param_data_[param]) {
           inputs.push_back(t);
+          shape_function_arg_types.push_back(TensorType(t->GetShape(), 
t->GetDataType()));
         }
       }
       if (state & kNeedInputShape) {
+        // Pass the shapes of the primitive arguments (also on the host)
         for (auto t : param_shapes_[param]) {
           inputs.push_back(t);
+          shape_function_arg_types.push_back(TensorType(t->GetShape(), 
t->GetDataType()));
         }
       }
     }
 
+    // TODO(mbs): This should be the definitive global by which the PrimFunc 
is known and
+    // no  other GlobalVar ctors should appear inside the lowering machinery.
     auto func_name = renamer(candidate_name);
     auto prim_fn_gvar = GlobalVar(func_name);
-    prim_fn_gvar->checked_type_ = prim_func->checked_type();
+
+    // Gather the result types, again from the p.o.v. of the shape function 
rather than
+    // the primitive it is derived for.
+    Array<Type> shape_function_res_types;
+    for (const auto& t : outputs) {
+      shape_function_res_types.push_back(TensorType(t->GetShape(), 
t->GetDataType()));
+    }
+
+    // Assign the shape function it's true type.

Review comment:
       typo: it's -> its

##########
File path: src/relay/backend/te_compiler_cache.cc
##########
@@ -471,10 +481,17 @@ class MakeShapeFunc : public 
backend::MemoizedExprTranslator<Array<te::Tensor>>
     With<PassContext> fresh_pass_ctx_scope(PassContext::Create());
 
     std::unordered_map<te::Tensor, tir::Buffer> binds;
-    IRModule ir_module = tvm::LowerSchedule(schedule, all_args, func_name, 
binds);
-
+    IRModule lowered_module = tvm::LowerSchedule(schedule, all_args, 
func_name, binds);
+
+    IRModule fixed_lowered_module;
+    for (const auto& kv : lowered_module->functions) {

Review comment:
       Why do we need to fix the lowered module?

##########
File path: src/relay/backend/vm/compiler.cc
##########
@@ -1051,6 +995,26 @@ transform::Sequential MemoryOpt(const SEScope& 
cpu_se_scope) {
   return transform::Sequential(std::move(pass_seqs));
 }
 
+transform::Sequential VMCompiler::LowerOperators(const SEScope& host_se_scope) 
{
+  Array<Pass> pass_seqs;

Review comment:
       Might be good to rename this to fuse and lower (or something that 
mentions fusion); in the places you use it, you also have a note that it fuses 
as well as lowers.

##########
File path: src/relay/backend/vm/compiler.cc
##########
@@ -481,50 +484,12 @@ class VMFunctionCompiler : 
DeviceAwareExprFunctor<void(const Expr& n)> {
     this->last_register_ = merge_register;
   }
 
-  void EmitShapeFunc(Function func, Array<Expr> inputs, Array<Expr> outputs) {

Review comment:
       Cool that you got rid of this!!

##########
File path: src/relay/backend/vm/compiler.cc
##########
@@ -1102,7 +1066,24 @@ IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) {
   pass_seqs.push_back(transform::ToANormalForm());
   pass_seqs.push_back(transform::InferType());
   pass_seqs.push_back(transform::LambdaLift());
-  pass_seqs.push_back(transform::InlinePrimitives());
+
+  // Eliminate dead-code before we lower. We don't track the purity of 
PrimFuncs, thus after
+  // lowering all calls to lowered functions will be kept.
+  pass_seqs.push_back(DeadCodeElimination(/*inline_once=*/false));
+  pass_seqs.push_back(transform::LabelOps());
+
+  // Lower all function's annotated as "primitive" by FuseOps.

Review comment:
       typo: function's -> functions

##########
File path: src/relay/transforms/type_infer.cc
##########
@@ -395,7 +400,7 @@ class TypeInferencer : private ExprFunctor<Type(const 
Expr&)>,
   //
   // The result will be the return type of the operator.
   Type PrimitiveCall(const FuncTypeNode* op, Array<Type> arg_types, const 
Attrs& attrs,
-                     const Span& span) {
+                     const Span& span, const Expr& expr) {

Review comment:
       I don't think you use `expr` in this function at all? Also, can you 
rename `op` to `func_type_node`? :-)

##########
File path: src/relay/transforms/device_planner.cc
##########
@@ -765,7 +765,8 @@ class DeviceCapturer : public ExprMutator {
 
   IRModule Capture() {
     VLOG_CONTEXT << "CaptureDevices";
-    IRModule result(/*functions=*/{}, mod_->type_definitions, mod_->Imports(), 
mod_->source_map);
+    IRModule result(/*functions=*/{}, mod_->type_definitions, mod_->Imports(), 
mod_->source_map,

Review comment:
       Hmm, maybe we should add a WithFields for IRModule. 

##########
File path: src/relay/backend/vm/compiler.cc
##########
@@ -481,50 +484,12 @@ class VMFunctionCompiler : 
DeviceAwareExprFunctor<void(const Expr& n)> {
     this->last_register_ = merge_register;
   }
 
-  void EmitShapeFunc(Function func, Array<Expr> inputs, Array<Expr> outputs) {
-    // Lower shape function
-    tec::CCacheKey key(func, host_se_scope_->target);
-    auto cfunc = context_->compiler->LowerShapeFunc(key);
-    int op_index = -1;
-    // pick the only function inside the context
-    ICHECK_EQ(cfunc->funcs->functions.size(), 1);
-    auto pfunc = 
Downcast<tir::PrimFunc>((*cfunc->funcs->functions.begin()).second);
-    if (context_->seen_funcs.count(pfunc) == 0) {
-      op_index = context_->cached_funcs.size();
-      context_->cached_funcs.push_back(cfunc);
-      context_->seen_funcs[pfunc] = op_index;
-    } else {
-      op_index = context_->seen_funcs[pfunc];
-    }
-
-    // Prepare input and output registers
+  void EmitInvokeTVMOp(const Expr& func, const Expr& inputs, const Expr& 
outputs,

Review comment:
       Why don't we need to pass SEScope in here anymore?

##########
File path: src/relay/op/memory/device_copy.cc
##########
@@ -117,15 +117,5 @@ DeviceCopyProps GetDeviceCopyProps(const Expr& expr) {
   return {};
 }
 
-DeviceCopyProps GetLoweredDeviceCopyProps(const CallLoweredProps& props) {

Review comment:
       glad to see this simplification!

##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -481,206 +557,217 @@ class LowerTensorExprMutator : public 
DeviceAwareExprMutator {
    */
   Expr MakeLoweredCall(Function func, Array<Expr> visited_args, Array<Type> 
type_args, Span span,
                        Target target) {
-    if (func->GetAttr<String>(attr::kCompiler).defined()) {
-      // BYOC flow.
-      CCacheKey key = CCacheKey(func, target);
-      CachedFunc ext_func = compiler_->Lower(key, module_name_);
-      ICHECK(ext_func.defined()) << "Lowering returned undefined function for "
-                                 << ext_func->prim_fn_var->name_hint;
-
-      // TODO(@areusch, @jroesch): this metadata is for AOT, this should be 
our interface for AOT
-      Map<GlobalVar, tir::PrimFunc> prim_fns;
-      relay::Function func_with_metadata = func;
-      func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", 
ext_func->prim_fn_var);
-      func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", 
prim_fns);
-      func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, 
ext_func->target);
-
-      // Provide a callback hook which allows one-level up code generators to
-      // act when we process a function.
-      this->process_fn_(func_with_metadata);
-
-      // TODO(mbs): Dynamic shapes?
-      // TODO(@mbs, electriclilies): Make extern functions explicit
-      return Call(ext_func->prim_fn_var, visited_args, Attrs(), type_args, 
span);
-
-    } else {
-      // Non-External Relay Function
-      CCacheKey key = CCacheKey(func, target);
-      CachedFunc lowered_func = compiler_->Lower(key, module_name_);
-
-      // Collect all the lowered functions produced for this primitive 
function.
-      Map<GlobalVar, tir::PrimFunc> prim_fns;
-      Array<GlobalVar> all_prim_fn_vars;
-      for (auto prim_fn : lowered_func->funcs->functions) {
-        CHECK(prim_fn.second.as<tir::PrimFuncNode>()) << "must be a prim fn";
-        prim_fns.Set(prim_fn.first, Downcast<tir::PrimFunc>(prim_fn.second));
-        all_prim_fn_vars.push_back(prim_fn.first);
-      }
-
-      // TODO(@areusch, @jroesch): this metadata is for AOT, this should be 
our interface for AOT
-      relay::Function func_with_metadata = func;
-      func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", 
lowered_func->prim_fn_var);
-      func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", 
prim_fns);
-      func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, 
lowered_func->target);
-
-      // Provide a callback hook which allows one-level up code generators to
-      // act when we process a function.
-      this->process_fn_(func_with_metadata);
-
-      auto call_lowered_attrs = make_object<CallLoweredAttrs>();
-      if (func->HasNonzeroAttr(attr::kReshapeOnly)) {
-        call_lowered_attrs->metadata.Set(attr::kReshapeOnly, tvm::Integer(1));
-      }
-
-      DeviceCopyProps props = GetDeviceCopyProps(func);
-      if (props.body.defined()) {
-        // Record the device copy source and destination scopes so the device 
planner can
-        // still follow along even after lowering.
-        call_lowered_attrs->metadata.Set("src_se_scope", props.src_se_scope);
-        call_lowered_attrs->metadata.Set("dst_se_scope", props.dst_se_scope);
+    CCacheKey key = CCacheKey(func, target);
+    CachedFunc cfunc = compiler_->Lower(key, module_name_);
+    ICHECK(cfunc.defined());
+
+    auto opt_compiler = func->GetAttr<String>(attr::kCompiler);
+
+    // Add some metadata on top of the *original function* and invoke the 
callback so it can
+    // be captured.
+    // TODO(@areusch, @jroesch): this metadata is for AOT, this should be our 
interface for AOT
+    Map<GlobalVar, tir::PrimFunc> prim_fns;
+    Array<GlobalVar> all_prim_fn_vars;
+    for (const auto& kv : cfunc->funcs->functions) {
+      if (opt_compiler) {
+        // We expect just the original func but with just the ExternalSymbol 
attribute signalling
+        // the function (will be) compiled externally.
+        ICHECK(kv.second.as<FunctionNode>())
+            << PrettyPrint(kv.first) << " must be bound to an (external) 
Function";
+      } else {
+        // We expect one or more PrimFuncs, one of which corresponds to 'the' 
lowered primitive
+        // (and the rest in support of that via tir::Calls).
+        ICHECK(kv.second.as<tir::PrimFuncNode>())
+            << PrettyPrint(kv.first) << " must be bound to a PrimFunc";
+        prim_fns.Set(kv.first, Downcast<tir::PrimFunc>(kv.second));
+        all_prim_fn_vars.push_back(kv.first);
       }
+    }
+    Function func_with_metadata = func;
+    func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", 
cfunc->prim_fn_var);
+    func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns);
+    func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, 
cfunc->target);
+    this->process_fn_(func_with_metadata);
+
+    auto call_lowered_attrs = make_object<CallLoweredAttrs>();
+
+    // Non-External Relay Function
+    // TODO(mbs): "reshape" cleanup.
+    if (!opt_compiler && func->HasNonzeroAttr(attr::kReshapeOnly)) {
+      call_lowered_attrs->metadata.Set(attr::kReshapeOnly, tvm::Integer(1));
+    }
 
-      call_lowered_attrs->metadata.Set("relay_attrs", func->attrs);
-      call_lowered_attrs->metadata.Set("all_prim_fn_vars", all_prim_fn_vars);
-
-      if (IsDynamic(func->ret_type)) {
-        // Also lower the dynamic shape function.
-        // Shape function keys use the underlying primitive function as their 
'function',
-        // but the generic 'cpu' target as the target since all shape 
functions run
-        // on the host cpu irrespective of where the primitive runs.
-        // TODO(mbs): Cleanup target handling.
-        Target shape_target("llvm");
-        CCacheKey shape_key(func, shape_target);
-        CachedFunc lowered_shape_func = compiler_->LowerShapeFunc(shape_key);
-        // Capture the shape function's global var and parameters 'states' in 
call
-        // annotations so calling convention can be recovered.
-        // TODO(mbs): Capture all this as part of a 'call into TIR' construct 
once available.
-        // The way the shape function calling convention is derived and passed 
to call sites
-        // via the 'parameter states' could be improved.
-        call_lowered_attrs->metadata.Set("prim_shape_fn_var", 
lowered_shape_func->prim_fn_var);
-        call_lowered_attrs->metadata.Set("prim_shape_fn_states",
-                                         
lowered_shape_func->shape_func_param_states);
-        call_lowered_attrs->metadata.Set(
-            "prim_shape_fn_num_inputs",
-            Integer(static_cast<int>(lowered_shape_func->inputs.size())));
-        call_lowered_attrs->metadata.Set(
-            "prim_shape_fn_num_outputs",
-            Integer(static_cast<int>(lowered_shape_func->outputs.size())));
-        Array<GlobalVar> all_prim_shape_fn_vars;
-        for (auto prim_shape_fn : lowered_shape_func->funcs->functions) {
-          CHECK(prim_shape_fn.second.as<tir::PrimFuncNode>()) << "must be a 
prim fn";
-          all_prim_shape_fn_vars.push_back(prim_shape_fn.first);
-        }
-        call_lowered_attrs->metadata.Set("all_prim_shape_fn_vars", 
all_prim_shape_fn_vars);
+    call_lowered_attrs->metadata.Set("relay_attrs", func->attrs);
+    call_lowered_attrs->metadata.Set("all_prim_fn_vars", all_prim_fn_vars);
+
+    if (IsDynamic(func->ret_type)) {
+      // Also lower the companion dynamic shape function.
+      // Shape function keys use the underlying primitive function as their 
'function',
+      // but the generic 'cpu' target as the target since all shape functions 
run
+      // on the host cpu irrespective of where the primitive runs.
+      CCacheKey shape_key(func, host_se_scope_->target);
+      CachedFunc lowered_shape_func = compiler_->LowerShapeFunc(shape_key);
+
+      // Capture the shape function's global var and parameters 'states' in 
call
+      // annotations so calling convention can be recovered.
+      // TODO(mbs): Shape cleanup.
+      call_lowered_attrs->metadata.Set("prim_shape_fn_var", 
lowered_shape_func->prim_fn_var);
+      call_lowered_attrs->metadata.Set("prim_shape_fn_states",
+                                       
lowered_shape_func->shape_func_param_states);
+      call_lowered_attrs->metadata.Set(
+          "prim_shape_fn_num_inputs", 
Integer(static_cast<int>(lowered_shape_func->inputs.size())));
+      call_lowered_attrs->metadata.Set(
+          "prim_shape_fn_num_outputs",
+          Integer(static_cast<int>(lowered_shape_func->outputs.size())));
+      Array<GlobalVar> all_prim_shape_fn_vars;
+      for (const auto& kv : lowered_shape_func->funcs->functions) {
+        CHECK(kv.second.as<tir::PrimFuncNode>()) << "must be a prim fn";
+        all_prim_shape_fn_vars.push_back(kv.first);
       }
-      return CallLowered(lowered_func->prim_fn_var, visited_args, 
Attrs(call_lowered_attrs),
-                         type_args, span);
+      call_lowered_attrs->metadata.Set("all_prim_shape_fn_vars", 
all_prim_shape_fn_vars);
     }
+
+    return CallLowered(cfunc->prim_fn_var, std::move(visited_args), 
Attrs(call_lowered_attrs),
+                       type_args, std::move(span));
   }
 
   std::pair<Var, Expr> PreVisitLetBinding_(const Var& var, const Expr& value) 
final {
     Var new_var = Downcast<Var>(Mutate(var));
     Expr new_value = Mutate(value);
     BaseFunc prim_func = ResolveToPrimitive(new_value);
 
-    if (prim_func.defined() && !prim_func->IsInstance<tir::PrimFuncNode>()) {
-      // Remember let var is bound to (possibly indirectly) a non-tir 
primitive.
-      Function func = Downcast<Function>(prim_func);
-      primitive_functions_.emplace(var, func);
+    if (prim_func.defined()) {
+      // Remember let var is bound (possibly indirectly) to a primitive 
function.
+      primitive_functions_.emplace(var.get(), prim_func);
     }
     return {new_var, new_value};
   }
 
   Expr PostVisitLet_(const LetNode* pre_let_node, const LetNode* 
post_let_node) final {
     BaseFunc prim_func = ResolveToPrimitive(post_let_node->value);
-    if (prim_func.defined() && !prim_func->IsInstance<tir::PrimFuncNode>()) {
+    if (prim_func.defined()) {
       // Leaving let var scope
-      primitive_functions_.erase(pre_let_node->var);
+      primitive_functions_.erase(pre_let_node->var.get());
     }
     return DeviceAwareExprMutator::PostVisitLet_(pre_let_node, post_let_node);
   }
 
   Expr DeviceAwareVisitExpr_(const FunctionNode* function_node) override {
-    if (function_node->HasNonzeroAttr(attr::kPrimitive)) {
-      // Nothing to lower inside primitive functions.
+    if (function_node->HasNonzeroAttr(attr::kPrimitive) ||
+        function_node->GetAttr<String>(attr::kExternalSymbol)) {
+      // Nothing to lower inside primitive/external functions.
       return GetRef<Function>(function_node);
     } else {
       return DeviceAwareExprMutator::DeviceAwareVisitExpr_(function_node);
     }
   }
 
   Expr DeviceAwareVisitExpr_(const CallNode* call_node) override {
-    // Passes before lowering might insert a call_lowered to call a function 
that has already
-    // been lowered. Therefore we might see call_lowered ops here, but we 
don't need to do anything
-    // because ResolveToPrimitive returns null for all calls where the 
call_node->op is an OpNode
-    Call call = GetRef<Call>(call_node);
-
-    // Look for (indirect) calls to primitives.
-    BaseFunc prim_func = ResolveToPrimitive(call_node->op);
-    if (!prim_func.defined()) {
-      // Not a call_node to a primitive function.
-      if (const FunctionNode* fn = call_node->op.as<FunctionNode>()) {
-        this->process_fn_(GetRef<Function>(fn));
+    // We can see five forms of calls:
+    //  1. A 'normal' Relay call to a Function with the "primitive" attribute. 
We will need
+    //     to lower that to a global PrimFunc and rewrite the call to:
+    //       call_lowered(@new_global, (arg1, ..., argn), <attributes>)
+    //     However there are a few special forms which are excluded from this 
treatment, see
+    //     below.
+    //  2. A 'normal' Relay call to a Function with the "compiler" attribute. 
We will need
+    //     to invoke the appropriate BYOC toolchain function to yield a 
runtime module and
+    //     rewrite the call to the same form as above.
+    //  3. A 'normal' Relay call to a PrimFunc which has already been supplied 
via a global
+    //     definition. We rewrite to use the call_lowered form, but otherwise 
nothing else
+    //     needs to be done.
+    //  4. A 'normal' Relay call to a Relay Function without any special 
attribute. These
+    //     calls are not changed.
+    //  5. A call_lowered call from an earlier invocation of this pass.
+    // Note that ResolveToPrimitive will yield non-null only for cases 1-3.
+
+    // Look for (possibly indirect) calls to primitives.
+    BaseFunc primitive_func = ResolveToPrimitive(call_node->op);
+    if (!primitive_func.defined()) {
+      // Not a call to a primitive function we need to rewrite.
+      if (const auto* function_node = call_node->op.as<FunctionNode>()) {
+        process_fn_(GetRef<Function>(function_node));
       }
-      return ExprMutator::VisitExpr_(call_node);
+      return DeviceAwareExprMutator::DeviceAwareVisitExpr_(call_node);
     }
 
-    // Similarly transform arguments.
-    Array<Expr> visited_args;
+    // Prepare the arguments.
+    Array<Expr> new_args;
     for (const auto& arg : call_node->args) {
-      visited_args.push_back(VisitExpr(arg));
+      new_args.push_back(VisitExpr(arg));
     }
 
-    // Already lowered by other means so we don't need to mutate
+    // Special case: device_copies are left as calls to primitive operators
+    // (thus undoing FuseOps) so that each backend can handle them directly.
+    // TODO(mbs): device_copy cleanup. Would be better for FuseOps to just 
leave device_copy alone.
+    if (const auto* function_node = primitive_func.as<FunctionNode>()) {
+      DeviceCopyProps device_copy_props = 
GetDeviceCopyProps(function_node->body);
+      if (device_copy_props.body.defined()) {
+        ICHECK_EQ(new_args.size(), 1);
+        return DeviceCopy(new_args[0], device_copy_props.src_se_scope,
+                          device_copy_props.dst_se_scope);
+      }
+    }
+
+    // Special case: If already lowered by other means then we don't need 
to mutate
     // the call but we do need to mutate the arguments
-    if (prim_func->IsInstance<tir::PrimFuncNode>()) {
+    if (const auto* prim_func_node = primitive_func.as<tir::PrimFuncNode>()) {
       // Function should already be Target annotated by this point
       // but the TE Compiler metadata is still needed for the callback
       // TODO(Mousius) - Robustify this to not assume we're in the GlobalVar 
for Target Hooks
       GlobalVar prim_func_var = Downcast<GlobalVar>(call_node->op);
-      tir::PrimFunc downcast_prim_func = Downcast<tir::PrimFunc>(prim_func);
+      tir::PrimFunc prim_func = GetRef<tir::PrimFunc>(prim_func_node);
 
-      Map<GlobalVar, tir::PrimFunc> prim_fns = {{prim_func_var, 
downcast_prim_func}};
-      tir::PrimFunc func_with_metadata =
-          WithAttrs(downcast_prim_func, {
-                                            {"prim_fn_var", prim_func_var},
-                                            {"prim_funcs", prim_fns},
-                                        });
+      Map<GlobalVar, tir::PrimFunc> prim_fns = {{prim_func_var, prim_func}};
+      tir::PrimFunc func_with_metadata = WithAttrs(prim_func, {
+                                                                  
{"prim_fn_var", prim_func_var},
+                                                                  
{"prim_funcs", prim_fns},
+                                                              });
+
+      ICHECK(!IsDynamic(call_node->checked_type()));
+      auto call_lowered_attrs = make_object<CallLoweredAttrs>();
+      call_lowered_attrs->metadata.Set("relay_attrs", primitive_func->attrs);
 
-      this->process_fn_(func_with_metadata);
-      return Call(call_node->op, visited_args, call_node->attrs);
+      process_fn_(func_with_metadata);
+      return CallLowered(call_node->op, std::move(new_args), 
Attrs(std::move(call_lowered_attrs)),
+                         call_node->type_args, call_node->span);
     }
 
+    // Typical case: call to fused primitive Relay Function.
     // Find the desired target device.
     Target target;
-    if (prim_func->GetAttr<String>(attr::kCompiler).defined()) {
+    if (primitive_func->GetAttr<String>(attr::kCompiler).defined()) {
       // The generic 'external device' target.
+      // TODO(mbs): Retire once replaced unified BYOC compiler and target 
macihnery.

Review comment:
       typo: macihnery -> machinery




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to