This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ec24ae60a0 [BYOC] RelayToTIR custom codegen passes can still depend on 
dynamic shape functions (#11619)
ec24ae60a0 is described below

commit ec24ae60a028f5aae0fa2f1d8a668eb6bf366414
Author: Mark Shields <[email protected]>
AuthorDate: Thu Jun 9 21:22:49 2022 -0700

    [BYOC] RelayToTIR custom codegen passes can still depend on dynamic shape 
functions (#11619)
    
    In #11474 I got ready to switch CUTLASS from function-at-a-time to 
IRModule-at-a-time compilation.
    However my approach didn't handle dynamic shape functions, so I adjust it 
here.
    
    The idea is still that such passes will leave behind
    calls to 'extern' functions. However, converting those
    calls to 'call_lowered' form in
    MarkCompilerFunctionsAsExtern is too soon since only
    the TECompiler knows how to capture all the attributes
    necessary to support dynamic shape functions.
    
    So stop doing that in MarkCompilerFunctionsAsExtern and
    instead support this case properly in the TECompiler.
    
    While there try to chip away at the chronic lack of structure in 
te_compiler.cc. Every little bit helps.
    
    Add a basic unit test.
---
 src/relay/backend/aot_executor_codegen.cc          |   8 +-
 src/relay/backend/graph_executor_codegen.cc        |  27 +-
 src/relay/backend/interpreter.cc                   |   3 +-
 src/relay/backend/te_compiler.cc                   | 329 ++++++++++++++-------
 src/relay/backend/te_compiler.h                    |  32 +-
 src/relay/backend/vm/compiler.cc                   |  24 +-
 src/relay/transforms/compiler_function_utils.cc    |  51 ----
 src/relay/transforms/compiler_function_utils.h     |  11 +-
 tests/python/relay/backend/test_pass_lower_te.py   | 241 +++++++++++++++
 .../transform/test_compiler_function_utils.py      |   5 +-
 10 files changed, 503 insertions(+), 228 deletions(-)

diff --git a/src/relay/backend/aot_executor_codegen.cc 
b/src/relay/backend/aot_executor_codegen.cc
index 167afd2c5f..381cfa0c9d 100644
--- a/src/relay/backend/aot_executor_codegen.cc
+++ b/src/relay/backend/aot_executor_codegen.cc
@@ -1064,9 +1064,8 @@ class AOTExecutorCodegen : public MixedModeVisitor {
 
     mod = transform::ToANormalForm()(mod);
 
-    IRModule lowered_mod = tec::LowerTEPass(
-        mod_name,
-        [this, workspace_byte_alignment](BaseFunc func) {
+    IRModule lowered_mod =
+        tec::LowerTE(mod_name, config_, [this, 
workspace_byte_alignment](BaseFunc func) {
           // We need to maintain the constant map for external
           // functions so we pass this processing function which
           // allows us to process each function as we lower it.
@@ -1078,8 +1077,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
           // execute as a further pass, instead writing data to the
           // lowering process directly.
           tec::UpdateFunctionMetadata(func, this->function_metadata_, 
workspace_byte_alignment);
-        },
-        config_)(mod);
+        })(mod);
 
     auto lowered_main = lowered_mod->Lookup("main");
     auto lowered_main_func = GetRef<Function>(lowered_main.as<FunctionNode>());
diff --git a/src/relay/backend/graph_executor_codegen.cc 
b/src/relay/backend/graph_executor_codegen.cc
index 7dba23803f..af426e5c71 100644
--- a/src/relay/backend/graph_executor_codegen.cc
+++ b/src/relay/backend/graph_executor_codegen.cc
@@ -217,22 +217,19 @@ class GraphExecutorCodegen : public 
backend::MemoizedExprTranslator<std::vector<
       mod = WithAttr(mod, "main_func_info", func_info);
     }
 
-    IRModule lowered_mod = tec::LowerTEPass(
-        mod_name_,
-        [this](BaseFunc func) {
-          // We need to maintain the constant map for external
-          // functions so we pass this processing function which
-          // allows us to process each function as we lower it.
-          if (func->GetAttr<String>(attr::kCompiler).defined()) {
-            UpdateConstants(func, &params_);
-          }
+    IRModule lowered_mod = tec::LowerTE(mod_name_, config_, [this](BaseFunc 
func) {
+      // We need to maintain the constant map for external
+      // functions so we pass this processing function which
+      // allows us to process each function as we lower it.
+      if (func->GetAttr<String>(attr::kCompiler).defined()) {
+        UpdateConstants(func, &params_);
+      }
 
-          // TODO(@areusch, @jroesch): We should refactor this to
-          // execute as a further pass, instead writing data to the
-          // lowering process directly.
-          tec::UpdateFunctionMetadata(func, this->function_metadata_);
-        },
-        config_)(mod);
+      // TODO(@areusch, @jroesch): We should refactor this to
+      // execute as a further pass, instead writing data to the
+      // lowering process directly.
+      tec::UpdateFunctionMetadata(func, this->function_metadata_);
+    })(mod);
 
     Optional<backend::FunctionInfo> main_func_info =
         lowered_mod->GetAttr<backend::FunctionInfo>("main_func_info");
diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc
index 9661040eab..65a0fdc948 100644
--- a/src/relay/backend/interpreter.cc
+++ b/src/relay/backend/interpreter.cc
@@ -960,8 +960,7 @@ IRModule Prepare(IRModule mod, const CompilationConfig& 
config) {
        // eta expand to support constructors in argument position.
        transform::EtaExpand(
            /*expand_constructor=*/true, /*expand_global_var=*/false),
-       transform::InferType(),
-       tec::LowerTEPass(/*module_name=*/"intrp", [](BaseFunc func) { /* no-op 
*/ }, config)});
+       transform::InferType(), tec::LowerTE(/*module_name=*/"intrp", config)});
 
   transform::PassContext pass_ctx = transform::PassContext::Current();
   With<transform::PassContext> ctx(pass_ctx);
diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc
index c78f3abd6e..e9491b0a89 100644
--- a/src/relay/backend/te_compiler.cc
+++ b/src/relay/backend/te_compiler.cc
@@ -17,6 +17,76 @@
  * under the License.
  */
 
+/*!
+ * \file relay/backend/te_compiler.cc
+ * \brief Manages the transition from Relay "Primitive" \p Functions to TIR \p 
PrimFuncs. Also
+ * handles invocation of external codegen.
+ *
+ * \p LowerTE handles the following (as a monolithic blob of code):
+ *
+ *  - Most importantly, any function with the "Primitive" attribute is first 
converted to TE by
+ *    \p LowerToTECompute (see te_compiler_cache.cc) using each operator's 
'compute' function.
+ *    The TE is then 'scheduled' to TIR using the 'anchor' operator's 
'schedule' function. Both
+ *    of those functions come from the \p OpStrategy returned by the Python
+ *    'relay.backend.lower_call' function (see te_compiler.py).
+ *    The TIR is packed as a \p PrimFunc and introduced as a new global 
function. Calls to the
+ *    original "Primitive" function are then rewritten to the form:
+ *    \code
+ *      call_lowered(@new_global, (... original args...), attributes)
+ *    \endcode
+ *
+ *  - The above "Primitive" function can appear:
+ *     - As a global function
+ *     - As a let-bound function
+ *     - As an inline function, ie the 'op' of calls.
+ *    In all three cases it is possible for the same "Primitive" function to 
be called multiple
+ *    times, and that sharing must be respected.
+ *
+ *  - "Primitive" functions must have a "global_symbol" attribute matching 
their desired or
+ *    existing global name. Care is taken to ensure GlobalVars with the same 
name are shared.
+ *
+ *  - It is possible for multiple structurally equal "Primitive" functions to 
appear in the same
+ *    \p IRModule. Only one implementation should be generated, and all calls 
should share that
+ *    implementation.
+ *
+ *  - When later converting to DPS (see memory_alloc.cc) we must handle 
functions whose result
+ *    tensor shapes depend at runtime on the input tensor shapes and/or data.
+ *     - That dependency is first described in TE form (see \p MakeShapeFunc in
+ *       te_compiler_cache.cc), then scheduled to yield a 'dynamic shape 
function' \p PrimFunc.
+ *       This relies on each operator's "FShapeFunc" and "TShapeDataDependent" 
attributes.
+ *       Since shapes are rank-1 tensors everything can be reflected back down 
into the regular
+ *       TE/TIR forms.
+ *     - Then the call_lowered attributes must record everything about the 
dynamic shape function
+ *       later needed by memory_alloc.cc. We call this 'cross linking' the 
call with the shape
+ *       function.
+ *
+ *  - Two external codegen mechanisms are supported, both triggered by 
"Primitive" functions which
+ *    also have a "Compiler" attribute bound to $compiler:
+ *     - Function-at-a-time (old style): The primitive function is passed to 
the function
+ *       registered as 'relay.ext.$compiler'. The function returns a 
runtime::Module which
+ *       should return true for \p ImplementsFunction for the function's 
global name. That
+ *       module is added to the IRModule's "external_mods" attributes.
+ *     - IRModule-at-a-time (new style): The \p RelayToTIRTargetHook sub-pass 
looks for
+ *       $compiler names which correspond to TargetKind names with a \p 
RelayToTIR attribute.
+ *       The \p Pass bound to that attribute is run, and each such 'custom' 
pass can do what
+ *       it likes, including replacing Functions with PrimFuncs, or adding new 
runtime::Modules
+ *       to the IRModule's "external_mods" attribute.
+ *
+ *  - Calls to functions added by external codegen are also rewritten to 
call_lowered form, and
+ *    may also require cross-linking to dynamic shape functions. However, 
since the functions
+ *    are/will be implemented by a runtime::Module all the Relay type 
information is no longer
+ *    available. So the Relay definitions for these "Primitive" "Compiler" 
functions are retained
+ *    in the \p IRModule, but marked with the "Extern" attribute to signal the 
function is now
+ *    just for carrying metadata.
+ *
+ *  - Some operators are handled specially:
+ *     - 'reshape', since it's a no-op on the underlying tensor buffer, and 
this is handled by
+ *       condition tests in many passes.
+ *     - 'debug', since it's intercepted differently depending on runtimes.
+ *
+ * TODO(mbs): This desperately deserves a refactor to separate all these 
concerns. See Relax.
+ */
+
 #include "./te_compiler.h"
 
 #include <tvm/driver/driver_api.h>
@@ -222,7 +292,7 @@ class TECompilerImpl : public TECompilerNode {
         } else {
           // It is valid for the external codegen function to return null:
           //  - Unit tests can use it.
-          //  - The true compilation may have already been handled by a 
RelayToTIR custom hook pass
+          //  - The true compilation may have already been handled by a 
RelayToTIR custom pass
           //    on the Target's kind. The original Relay functions will be 
left in place so
           //    that we can capture that their function names are now 
externally defined.
           VLOG(1) << "Note that no external runtime module was generated by 
external codegen '"
@@ -566,100 +636,128 @@ class LowerTensorExprMutator : public 
DeviceAwareExprMutator {
         return itr->second;
       }
     } else if (const auto* function_node = expr.as<FunctionNode>()) {
-      if (!function_node->HasNonzeroAttr(attr::kPrimitive)) {
-        // Not marked as primitive by FuseOps.
-        return {};
-      }
-      if (const auto* call_node = function_node->body.as<CallNode>()) {
-        if (call_node->op == debug_op_) {
-          // Debug 'primitives' are not lowered.
-          return {};
+      if (function_node->HasNonzeroAttr(attr::kExtern)) {
+        // We have a regular call to an 'extern' function. The call itself 
needs to be rewritten
+        // to call_lowered form, and any required dynamic shape functions 
generated and
+        // cross-linked.
+        return GetRef<Function>(function_node);
+      } else if (function_node->HasNonzeroAttr(attr::kPrimitive)) {
+        if (const auto* call_node = function_node->body.as<CallNode>()) {
+          if (call_node->op == debug_op_) {
+            // Debug 'primitives' are not lowered.
+            return {};
+          }
         }
+        // We have a regular call to a 'primitive' function (possibly with a 
'Compiler' attribute).
+        // We need to lower and rewrite the call.
+        return GetRef<Function>(function_node);
+      } else {
+        // Not marked as primitive during partitioning or TVM fusion.
+        return {};
       }
-      return GetRef<Function>(function_node);
     } else {
       return {};
     }
   }
 
   /*!
-   * \brief Lowers the primitive function \p func to TIR for ultimate execution
-   * on a device with configuration \p target. Returns the global var bound
-   * to the TIR implementation, and attributes to attach to the call to 
identify it as
-   * a TIR call.
+   * \brief Returns a 'call_lowered' call to \p prim_fn_var with \p args and 
\p span with all the
+   * required attributes filled in. Generally \p prim_fn_var will correspond 
to the lowered or
+   * externally codegen-ed form of \p original_function, where \p 
lowered_functions binds all
+   * the required lowered functions.
+   *
+   * The call's attributes will capture:
+   *  - Any attributes on the original_function.
+   *  - All the lowered functions.
+   *    TODO(mbs): Pretty sure that's no longer needed.
+   *  - Details needed to cross-link the call to its dynamic shape function, 
if any.
    */
-  Expr MakeLoweredCall(Function func, Array<Expr> visited_args, Span span, 
Target target) {
-    CCacheKey key = CCacheKey(func, target);
-    CachedFunc cfunc = compiler_->Lower(key, module_name_);
-    ICHECK(cfunc.defined());
-
-    auto opt_compiler = func->GetAttr<String>(attr::kCompiler);
+  Expr MakeLoweredCall(const BaseFunc& original_function, const GlobalVar& 
prim_fn_var,
+                       Array<Expr> args, Span span, const Target& target,
+                       const Map<GlobalVar, BaseFunc>& lowered_functions) {
+    auto opt_compiler = original_function->GetAttr<String>(attr::kCompiler);
 
     // Add some metadata on top of the *original function* and invoke the 
callback so it can
     // be captured.
     // TODO(@areusch, @jroesch): this metadata is for AOT, this should be our 
interface for AOT
     Map<GlobalVar, tir::PrimFunc> prim_fns;
     Array<GlobalVar> all_prim_fn_vars;
-    for (const auto& kv : cfunc->funcs->functions) {
+    for (const auto& kv : lowered_functions) {
       if (opt_compiler) {
-        // We expect just the original func but with just the ExternalSymbol 
attribute signaling
-        // the function (will be) compiled externally.
+        // We expect the original function to have just the "Extern" attribute 
signaling the
+        // function (will be) compiled externally.
         ICHECK(kv.second.as<FunctionNode>())
             << PrettyPrint(kv.first) << " must be bound to an (external) 
Function";
       } else {
-        // We expect one or more PrimFuncs, one of which corresponds to 'the' 
lowered primitive
-        // (and the rest in support of that via tir::Calls).
+        // We expect one or more PrimFuncs, one of which corresponds to 'the' 
lowered primitive,
+        // and the rest are in support of that via tir::Calls.
         ICHECK(kv.second.as<tir::PrimFuncNode>())
             << PrettyPrint(kv.first) << " must be bound to a PrimFunc";
         prim_fns.Set(kv.first, Downcast<tir::PrimFunc>(kv.second));
         all_prim_fn_vars.push_back(kv.first);
       }
     }
-    Function func_with_metadata = func;
-    func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", 
cfunc->prim_fn_var);
-    func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns);
-    func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, 
cfunc->target);
-    this->process_fn_(func_with_metadata);
 
+    // Alas, WithAttr cannot work with base classes.
+    if (const auto* prim_func_node = original_function.as<te::PrimFuncNode>()) 
{
+      auto func_with_metadata = GetRef<te::PrimFunc>(prim_func_node);
+      func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", 
prim_fn_var);
+      func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", 
prim_fns);
+      func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, 
target);
+      this->process_fn_(func_with_metadata);
+    } else {
+      const auto* function_node = original_function.as<FunctionNode>();
+      ICHECK(function_node);
+      auto func_with_metadata = GetRef<Function>(function_node);
+      func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", 
prim_fn_var);
+      func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", 
prim_fns);
+      func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, 
target);
+      this->process_fn_(func_with_metadata);
+    }
+
+    // Now prepare the attributes of the call_lowered.
     CallLoweredAttrs call_lowered_attrs;
 
-    // Non-External Relay Function
     // TODO(mbs): "reshape" cleanup.
-    if (!opt_compiler && func->HasNonzeroAttr(attr::kReshapeOnly)) {
+    if (!opt_compiler && 
original_function->HasNonzeroAttr(attr::kReshapeOnly)) {
       call_lowered_attrs.metadata.Set(attr::kReshapeOnly, tvm::Integer(1));
     }
 
-    call_lowered_attrs.metadata.Set("relay_attrs", func->attrs);
+    call_lowered_attrs.metadata.Set("relay_attrs", original_function->attrs);
     call_lowered_attrs.metadata.Set("all_prim_fn_vars", all_prim_fn_vars);
 
-    if (IsDynamic(func->ret_type)) {
-      // Also lower the companion dynamic shape function.
-      // Shape function keys use the underlying primitive function as their 
'function',
-      // but the generic 'cpu' target as the target since all shape functions 
run
-      // on the host cpu irrespective of where the primitive runs.
-      CCacheKey shape_key(func, config_->host_virtual_device->target);
-      CachedFunc lowered_shape_func = compiler_->LowerShapeFunc(shape_key);
-
-      // Capture the shape function's global var and parameters 'states' in 
call
-      // annotations so calling convention can be recovered.
-      // TODO(mbs): Shape cleanup.
-      call_lowered_attrs.metadata.Set("prim_shape_fn_var", 
lowered_shape_func->prim_fn_var);
-      call_lowered_attrs.metadata.Set("prim_shape_fn_states",
-                                      
lowered_shape_func->shape_func_param_states);
-      call_lowered_attrs.metadata.Set("prim_shape_fn_num_inputs",
-                                      
Integer(static_cast<int>(lowered_shape_func->inputs.size())));
-      call_lowered_attrs.metadata.Set(
-          "prim_shape_fn_num_outputs",
-          Integer(static_cast<int>(lowered_shape_func->outputs.size())));
-      Array<GlobalVar> all_prim_shape_fn_vars;
-      for (const auto& kv : lowered_shape_func->funcs->functions) {
-        CHECK(kv.second.as<tir::PrimFuncNode>()) << "must be a prim fn";
-        all_prim_shape_fn_vars.push_back(kv.first);
+    if (const auto* function_node = original_function.as<FunctionNode>()) {
+      if (IsDynamic(function_node->ret_type)) {
+        // Create a dynamic shape function to calculate the expected shape of 
the results of
+        // the lowered function.
+        // Shape function keys use the original function as their 'function', 
but the generic 'cpu'
+        // target as the target since all shape functions run on the host cpu 
irrespective of where
+        // the primitive runs.
+        CCacheKey shape_key(GetRef<Function>(function_node), 
config_->host_virtual_device->target);
+        CachedFunc lowered_shape_func = compiler_->LowerShapeFunc(shape_key);
+
+        // Capture the shape function's global var and parameters 'states' in 
call
+        // annotations so calling convention can be recovered.
+        // TODO(mbs): Shape cleanup.
+        call_lowered_attrs.metadata.Set("prim_shape_fn_var", 
lowered_shape_func->prim_fn_var);
+        call_lowered_attrs.metadata.Set("prim_shape_fn_states",
+                                        
lowered_shape_func->shape_func_param_states);
+        call_lowered_attrs.metadata.Set(
+            "prim_shape_fn_num_inputs",
+            Integer(static_cast<int>(lowered_shape_func->inputs.size())));
+        call_lowered_attrs.metadata.Set(
+            "prim_shape_fn_num_outputs",
+            Integer(static_cast<int>(lowered_shape_func->outputs.size())));
+        Array<GlobalVar> all_prim_shape_fn_vars;
+        for (const auto& kv : lowered_shape_func->funcs->functions) {
+          CHECK(kv.second.as<tir::PrimFuncNode>()) << "must be a prim fn";
+          all_prim_shape_fn_vars.push_back(kv.first);
+        }
+        call_lowered_attrs.metadata.Set("all_prim_shape_fn_vars", 
all_prim_shape_fn_vars);
       }
-      call_lowered_attrs.metadata.Set("all_prim_shape_fn_vars", 
all_prim_shape_fn_vars);
     }
 
-    return CallLowered(cfunc->prim_fn_var, std::move(visited_args), 
std::move(call_lowered_attrs),
+    return CallLowered(prim_fn_var, std::move(args), 
std::move(call_lowered_attrs),
                        std::move(span));
   }
 
@@ -697,43 +795,51 @@ class LowerTensorExprMutator : public 
DeviceAwareExprMutator {
   }
 
   Expr DeviceAwareVisitExpr_(const CallNode* call_node) override {
-    // We can see five forms of calls:
-    //  1. A 'normal' Relay call to a Function with the "primitive" attribute. 
We will need
-    //     to lower that to a global PrimFunc and rewrite the call to:
+    // We can see six forms of calls:
+    //  1. A 'normal' Relay call to a Function with the "Primitive" attribute 
and not "Compiler"
+    //     attribute. We will need to lower that to a global PrimFunc and 
rewrite the call to:
     //       call_lowered(@new_global, (arg1, ..., argn), <attributes>)
-    //     However there are a few special forms which are excluded from this 
treatment, see
-    //     below.
-    //  2. A 'normal' Relay call to a Function with the "compiler" attribute. 
We will need
-    //     to invoke the appropriate BYOC toolchain function to yield a 
runtime module and
-    //     rewrite the call to the same form as above.
-    //  3. A 'normal' Relay call to a PrimFunc which has already been supplied 
via a global
-    //     definition. We rewrite to use the call_lowered form, but otherwise 
nothing else
+    //     If needed, the call needs to be cross-linked with any dynamic shape 
functions.
+    //     (However, some primitives are special and handled separately.)
+    //  2. A 'normal' Relay call to a Function with the "Primitive" and 
"Compiler" attributes. We
+    //     will need to invoke the "relay.ext.<compiler>" function to yield a 
runtime module, and
+    //     rewrite the call to the same form as above. Dynamic shape function 
cross-linking may
+    //     also be needed.
+    //  3. A 'normal' Relay call to a Function with the "Extern" attribute. 
This function has
+    //     already been compiled by an external codegen and a definition for 
it exists in some
+    //     runtime module. Again, we rewrite to call_lowered form, and 
cross-link with a dynamic
+    //     shape function if needed.
+    //  4. A 'normal' Relay call to a PrimFunc which has already been supplied 
via a global
+    //     definition. We rewrite those to use the call_lowered form, but 
otherwise nothing else
     //     needs to be done.
-    //  4. A 'normal' Relay call to a Relay Function without any special 
attribute. These
+    //  5. A 'call_lowered' call from an earlier invocation of this pass or 
otherwise deliberately
+    //     inserted. It has all the required attributes, and any associated 
dynamic shape function
+    //     has been generated and cross-linked. These calls are not changed.
+    //  6. A 'normal' Relay call to a Relay Function without any special 
attribute. These
     //     calls are not changed.
-    //  5. A call_lowered call from an earlier invocation of this pass.
-    // Note that ResolveToPrimitive will yield non-null only for cases 1-3.
+    //
+    // Note that ResolveToPrimitive will yield non-null only for cases 1-4.
+
+    // Prepare the arguments and op.
+    Array<Expr> new_args;
+    for (const auto& arg : call_node->args) {
+      new_args.push_back(VisitExpr(arg));
+    }
+    Expr new_op = VisitExpr(call_node->op);
 
     // Look for (possibly indirect) calls to primitives.
     BaseFunc primitive_func = ResolveToPrimitive(call_node->op);
     if (!primitive_func.defined()) {
-      // Not a call to a primitive function we need to rewrite.
+      // Cases 5 and 6: Leave as ordinary call.
       if (const auto* function_node = call_node->op.as<FunctionNode>()) {
         process_fn_(GetRef<Function>(function_node));
       }
-      return DeviceAwareExprMutator::DeviceAwareVisitExpr_(call_node);
-    }
-
-    // Prepare the arguments.
-    Array<Expr> new_args;
-    for (const auto& arg : call_node->args) {
-      new_args.push_back(VisitExpr(arg));
+      return WithFields(GetRef<Call>(call_node), std::move(new_op), 
std::move(new_args));
     }
 
-    // Special case: device_copies are left as calls to primitive operators
-    // (thus undoing FuseOps) so that each backend can handle them directly.
-    // TODO(mbs): device_copy cleanup. Would be better for FuseOps to just 
leave device_copy
-    // alone.
+    // Special case for case 1: device_copies are left as calls to primitive 
operators
+    // so that each backend can handle them directly.
+    // TODO(mbs): device_copy cleanup. Would be better for FuseOps to just 
leave device_copy alone.
     if (const auto* function_node = primitive_func.as<FunctionNode>()) {
       DeviceCopyProps device_copy_props = 
GetDeviceCopyProps(function_node->body);
       if (device_copy_props.body.defined()) {
@@ -743,33 +849,23 @@ class LowerTensorExprMutator : public 
DeviceAwareExprMutator {
       }
     }
 
-    // Special case: If already lowered by other means then so we don't need 
to mutate
-    // the call but we do need to mutate the arguments
+    ICHECK(call_node->type_args.empty()) << "lowered functions cannot be 
polymorphic";
+
+    // Case 4: If the function has already been lowered we just need to update 
the call.
     if (const auto* prim_func_node = primitive_func.as<tir::PrimFuncNode>()) {
       // Function should already be Target annotated by this point
       // but the TE Compiler metadata is still needed for the callback
       // TODO(Mousius) - Robustify this to not assume we're in the GlobalVar 
for Target Hooks
-      GlobalVar prim_func_var = Downcast<GlobalVar>(call_node->op);
+      Optional<Target> opt_target = 
primitive_func->GetAttr<Target>(tvm::attr::kTarget);
+      ICHECK(opt_target.defined());
+      auto prim_fn_var = Downcast<GlobalVar>(call_node->op);
       tir::PrimFunc prim_func = GetRef<tir::PrimFunc>(prim_func_node);
-
-      Map<GlobalVar, tir::PrimFunc> prim_fns = {{prim_func_var, prim_func}};
-      tir::PrimFunc func_with_metadata = WithAttrs(prim_func, {
-                                                                  
{"prim_fn_var", prim_func_var},
-                                                                  
{"prim_funcs", prim_fns},
-                                                              });
-
-      ICHECK(!IsDynamic(call_node->checked_type()));
-      CallLoweredAttrs call_lowered_attrs;
-      call_lowered_attrs.metadata.Set("relay_attrs", primitive_func->attrs);
-
-      process_fn_(func_with_metadata);
-      ICHECK(call_node->type_args.empty()) << "lowered functions cannot be 
polymorphic";
-      return CallLowered(prim_func_var, std::move(new_args), 
std::move(call_lowered_attrs),
-                         call_node->span);
+      Map<GlobalVar, BaseFunc> prim_fns = {{prim_fn_var, prim_func}};
+      return MakeLoweredCall(primitive_func, prim_fn_var, std::move(new_args), 
call_node->span,
+                             opt_target.value(), prim_fns);
     }
 
-    // Typical case: call to fused primitive Relay Function.
-    // Find the desired target device.
+    // Determine the target for lowering or external codegen.
     Target target;
     Optional<String> opt_compiler = 
primitive_func->GetAttr<String>(attr::kCompiler);
     if (opt_compiler.defined()) {
@@ -791,10 +887,20 @@ class LowerTensorExprMutator : public 
DeviceAwareExprMutator {
       ICHECK(target.defined());
     }
 
-    // Lower the primitive function for that target.
-    Function function = Downcast<Function>(primitive_func);
-    ICHECK(call_node->type_args.empty()) << "lowered functions cannot be 
polymorphic";
-    return MakeLoweredCall(function, std::move(new_args), call_node->span, 
target);
+    if (primitive_func->HasNonzeroAttr(attr::kExtern)) {
+      // Case 3: Function has already been compiled.
+      GlobalVar prim_fn_var = Downcast<GlobalVar>(call_node->op);
+      return MakeLoweredCall(primitive_func, prim_fn_var, std::move(new_args), 
call_node->span,
+                             target, /*lowered_functions=*/{});
+    } else {
+      // Cases 1 and 2: lower the primitive function for the desired target, 
possibly using external
+      // codegen.
+      CCacheKey key(Downcast<Function>(primitive_func), target);
+      CachedFunc cfunc = compiler_->Lower(key, module_name_);
+      ICHECK(cfunc.defined());
+      return MakeLoweredCall(primitive_func, cfunc->prim_fn_var, 
std::move(new_args),
+                             call_node->span, target, cfunc->funcs->functions);
+    }
   }
 
   IRModule module_;
@@ -1046,6 +1152,7 @@ void UpdateFunctionMetadata(BaseFunc func,
   function_metadata.Set(prim_fn_var.value()->name_hint, fi);
 }
 
+/*! \brief Main lowering driver. */
 IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn 
process_fn,
                  CompilationConfig config) {
   TECompiler compiler(module);
@@ -1163,7 +1270,7 @@ Map<Target, IRModule> GetPerTargetModules(IRModule mod) {
   return per_target_modules;
 }
 
-Pass LowerTEPass(String module_name, ProcessFn process_fn, CompilationConfig 
complilation_config) {
+Pass LowerTE(String module_name, CompilationConfig complilation_config, 
ProcessFn process_fn) {
   runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = 
[=](IRModule module,
                                                                             
PassContext ctx) {
     return LowerTE(module, module_name, process_fn, complilation_config);
@@ -1174,6 +1281,12 @@ Pass LowerTEPass(String module_name, ProcessFn 
process_fn, CompilationConfig com
        tvm::transform::CreateModulePass(pass_func, 0, "LowerTE", 
{"InferType"}), InferType(),
        tvm::tir::transform::ExtractPrimFuncConstants()});
 }
+
+TVM_REGISTER_GLOBAL("relay.tec.LowerTE")
+    .set_body_typed([](String module_name, CompilationConfig 
compilation_config) {
+      return LowerTE(std::move(module_name), std::move(compilation_config));
+    });
+
 }  // namespace tec
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h
index 8312a20cb8..5d16da4b8b 100644
--- a/src/relay/backend/te_compiler.h
+++ b/src/relay/backend/te_compiler.h
@@ -18,8 +18,8 @@
  */
 
 /*!
- * \file relay/backend/tir_compiler.h
- *  * \brief Internal compilation layer which lowers Relay "primitive 
functions" to TIR PrimFns.
+ * \file relay/backend/te_compiler.h
+ * \brief Internal compilation layer which lowers Relay "primitive functions" 
to TIR PrimFns.
  *
  *
  * This represents the new design of the Relay compilation flow and will 
replace the interface
@@ -173,36 +173,22 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const 
IRModule& mod, const Compila
  */
 Map<Target, IRModule> GetPerTargetModules(IRModule mod);
 
-/*! \brief Lower an IRModule's primitive functions to TIR.
- *
- * This is the "back half" of the Relay compiler which lowers "primitive 
functions"
- * to TE expressions, schedules them, and then to TIR.
- *
- * \param module The IRModule.
- * \param memory_plan The memory plan used during lowering
- * \param module_name The name of this module
- * \param process_fn Callback allowing one-level up code generators to process
- * each function that we lower
- * \return The lowered module, see above.
- */
-IRModule LowerTE(
-    const IRModule& module, backend::StaticMemoryPlan memory_plan, const 
String& module_name,
-    ProcessFn process_fn = [](BaseFunc f) {});
+inline void DefaultProcessFn(BaseFunc) {}
 
 /*!
  * \brief Pass to lower an IRModule's primitive functions to TIR.
  *
  * This is the "back half" of the Relay compiler which lowers "primitive 
functions"
- * to TE expressions, schedules them, and then to TIR. It annotates all 
functions
- * with their target.
+ * to TE expressions, schedules them, and emits PrimFuncs.
  *
- * \param module_name The name of this module
- * \param process_fn Callback allowing one-level up code generators to process
- * each function that we lower
+ * \param module_name The name of this module, used as a prefix for generated 
globals.
  * \param config All available targets.
+ * \param process_fn Callback allowing one-level up code generators to process
+ * each function that we lower (default is no-op).
  * \returns The pass which lowers primitive functions to TIR
  */
-transform::Pass LowerTEPass(String module_name, ProcessFn process_fn, 
CompilationConfig config);
+transform::Pass LowerTE(String module_name, CompilationConfig config,
+                        ProcessFn process_fn = DefaultProcessFn);
 
 }  // namespace tec
 }  // namespace relay
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index 48f12ea8aa..8820a403bf 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -1040,13 +1040,11 @@ transform::Sequential 
VMCompiler::FuseAndLowerOperators(const CompilationConfig&
   // Give each "primitive" Function a hash.
   pass_seqs.push_back(LabelOps());
   // Lower "primitive" Functions to PrimFuncs and rewrite calls.
-  pass_seqs.push_back(tec::LowerTEPass(/*module_name=*/"vm_mod",
-                                       [this](const BaseFunc& func) {
-                                         if 
(func->GetAttr<String>(attr::kCompiler).defined()) {
-                                           backend::UpdateConstants(func, 
&params_);
-                                         }
-                                       },
-                                       config));
+  pass_seqs.push_back(tec::LowerTE(/*module_name=*/"vm_mod", config, 
[this](const BaseFunc& func) {
+    if (func->GetAttr<String>(attr::kCompiler).defined()) {
+      backend::UpdateConstants(func, &params_);
+    }
+  }));
   // Since lowered functions are bound in the IRModule, we can now eliminate 
any unused
   // let-bound functions.
   pass_seqs.push_back(DeadCodeElimination(/*inline_once=*/false));
@@ -1091,13 +1089,11 @@ IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) {
   pass_seqs.push_back(transform::LabelOps());
 
   // Lower all functions annotated as "primitive" by FuseOps.
-  pass_seqs.push_back(tec::LowerTEPass(/*module_name=*/"vm_mod",
-                                       [this](const BaseFunc& func) {
-                                         if 
(func->GetAttr<String>(attr::kCompiler).defined()) {
-                                           backend::UpdateConstants(func, 
&params_);
-                                         }
-                                       },
-                                       config_));
+  pass_seqs.push_back(tec::LowerTE(/*module_name=*/"vm_mod", config_, 
[this](const BaseFunc& func) {
+    if (func->GetAttr<String>(attr::kCompiler).defined()) {
+      backend::UpdateConstants(func, &params_);
+    }
+  }));
 
   // Since lowered functions are bound in the IRModule, we can now eliminate 
any unused
   // let-bound functions.
diff --git a/src/relay/transforms/compiler_function_utils.cc 
b/src/relay/transforms/compiler_function_utils.cc
index f22e9bd80d..3df07e4c57 100644
--- a/src/relay/transforms/compiler_function_utils.cc
+++ b/src/relay/transforms/compiler_function_utils.cc
@@ -81,42 +81,6 @@ class Outliner : public MixedModeMutator {
   IRModule mod_;
 };
 
-/*!
- * \brief Rewrite calls to global "Compiler" functions to use the 
'call_lowered' convention.
- */
-class CallRewriter : public MixedModeMutator {
- public:
-  CallRewriter(std::string compiler_filter, IRModule mod)
-      : compiler_filter_(std::move(compiler_filter)), mod_(std::move(mod)) {}
-
-  Expr Rewrite_(const CallNode* pre, const Expr& post) final {
-    Call new_call = Downcast<Call>(post);
-    if (const auto* global_var_node = new_call->op.as<GlobalVarNode>()) {
-      if (const auto* function_node =
-              
mod_->Lookup(GetRef<GlobalVar>(global_var_node)).as<FunctionNode>()) {
-        Optional<String> opt_compiler = 
function_node->GetAttr<String>(attr::kCompiler);
-        if (opt_compiler.defined() &&
-            (compiler_filter_.empty() || opt_compiler.value() == 
compiler_filter_)) {
-          Optional<String> opt_global_symbol =
-              function_node->GetAttr<String>(tvm::attr::kGlobalSymbol);
-          ICHECK(opt_global_symbol.defined());
-          GlobalVar global_symbol = 
mod_->GetGlobalVar(opt_global_symbol.value());
-          CallLoweredAttrs attrs;
-          attrs.metadata.Set("relay_attrs", new_call->attrs);
-          return CallLowered(global_symbol, new_call->args, attrs, 
new_call->span);
-        }
-      }
-    }
-    return post;
-  }
-
- private:
-  /*! \brief If non-empty, the "Compiler" attribute value to require on 
functions to outline. */
-  std::string compiler_filter_;
-  /*! \brief Module being rewritten. */
-  IRModule mod_;
-};
-
 }  // namespace
 
 GlobalSymbolCache::~GlobalSymbolCache() = default;
@@ -169,20 +133,6 @@ transform::Pass MarkCompilerFunctionsAsExtern(std::string 
compiler_filter) {
   runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> 
pass_func =
       [compiler_filter = std::move(compiler_filter)](IRModule mod, 
transform::PassContext ctx) {
         IRModule output_mod = mod->ShallowCopy();
-
-        // First pass, rewrite the calls.
-        // We have to do this before marking functions as 'extern' to know 
which calls to rewrite!
-        for (const auto& kv : mod->functions) {
-          if (const auto* function_node = 
AsOptimizableFunctionNode(kv.second)) {
-            Expr new_body =
-                CallRewriter(compiler_filter, 
output_mod).VisitExpr(function_node->body);
-            Function new_function =
-                WithFields(GetRef<Function>(function_node), /*opt_params=*/{}, 
new_body);
-            output_mod->Update(kv.first, new_function);
-          }
-        }
-
-        // Second pass, mark functions as 'extern'.
         for (const auto& kv : mod->functions) {
           if (const auto* function_node = kv.second.as<FunctionNode>()) {
             Optional<String> opt_compiler = 
function_node->GetAttr<String>(attr::kCompiler);
@@ -197,7 +147,6 @@ transform::Pass MarkCompilerFunctionsAsExtern(std::string 
compiler_filter) {
             }
           }
         }
-
         return output_mod;
       };
 
diff --git a/src/relay/transforms/compiler_function_utils.h 
b/src/relay/transforms/compiler_function_utils.h
index e4b1f05211..9d1dcd9f21 100644
--- a/src/relay/transforms/compiler_function_utils.h
+++ b/src/relay/transforms/compiler_function_utils.h
@@ -43,11 +43,8 @@
  *
  *  - \p MarkCompilerFunctionsAsExtern will replace global functions with a 
matching "Compiler"
  *    attribute with the same function with just  an "Extern" attribute, 
signalling the function
- *    has been dealt with. Calls to such functions will be rewritten to use 
the 'call_lowered'
- *    calling convention. Can be used after lowering to cleanup the IRModule.
- *
- * Note that the above behaviour is hard coded within the TECompiler, but is 
only available to
- * external codegen using the Function-at-a-time "relay.ext.toolchain" 
extension point.
+ *    has been dealt with. However calls to such functions will be left 
unchanged.  Can be used
+ *    after lowering to cleanup the IRModule.
  */
 
 #ifndef TVM_RELAY_TRANSFORMS_COMPILER_FUNCTION_UTILS_H_
@@ -118,8 +115,8 @@ transform::Pass 
OutlineCompilerFunctionsWithExistingGlobalSymbols(std::string co
 
 /*!
  * \brief A pass to mark all global functions which have a "Compiler" 
attribute matching
- * compiler_filter as 'extern' by replacing all attributes with a single 
"Extern" attribute, and
- * rewrite all calls to such functions to use the 'call_lowered' calling 
convention.
+ * compiler_filter as 'extern' by replacing all attributes with a single 
"Extern" attribute.
+ * Calls to such functions are not changed.
  *
  * If \p compiler_filter is non-empty only functions with that as their 
attribute value are
  * outlined.
diff --git a/tests/python/relay/backend/test_pass_lower_te.py 
b/tests/python/relay/backend/test_pass_lower_te.py
new file mode 100644
index 0000000000..310a16e269
--- /dev/null
+++ b/tests/python/relay/backend/test_pass_lower_te.py
@@ -0,0 +1,241 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Exercises the LowerTE pass.
+
+import tvm
+import tvm.testing
+import logging
+
+logging.basicConfig()
+logger = logging.getLogger("test_pass_lower_te")
+logger.setLevel(logging.INFO)
+
+# Since the TE compiler needs a good refactor it has not been exposed as a 
'standard' pass
+# in relay.transform. For testing grab it directly.
+LowerTE = tvm._ffi.get_global_func("relay.tec.LowerTE")
+
+
+def transform(mod):
+    logger.info("Starting module:\n%s", mod)
+    host_target = tvm.target.Target("llvm")
+    prim_target = tvm.target.Target("llvm", host=host_target)
+    ctxt = tvm.transform.PassContext()
+    config = tvm.target.make_compilation_config(ctxt, prim_target)
+    mod = tvm.relay.transform.PlanDevices(config)(mod)
+    mod = tvm.relay.transform.InferType()(mod)
+    mod = LowerTE("test", config)(mod)
+    mod = tvm.relay.transform.InferType()(mod)
+    logger.info("After LowerTE:\n%s", mod)
+    return mod
+
+
+# All attempts to use structural equality tests against an expected IRModule 
parsed from
+# Relay text were thwarted by the difficulty of setting up the expected 
call_lowered attributes
+# with the right GlobalVar instances. So the following tests assert structural 
correctness the hard way.
+
+
+def test_lower_primitive():
+    input_mod = tvm.parser.parse(
+        """
+        #[version = "0.0.5"]
+        def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] {
+          %0 = fn(%x : Tensor[(5, 7), float32], %y : Tensor[(5, 7), float32], 
Primitive=1) -> Tensor[(5, 7), float32] { 
+            add(%x, %y)
+          };
+          %0(%a, %a)  
+        }      
+        """,
+        "from_string",
+        None,
+        None,
+    )
+
+    actual_mod = transform(input_mod)
+
+    # Expected:
+    #   def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] {
+    #     %0 = (%a, %a);
+    #     call_lowered(@test_fused_add, %0, 
metadata={relay_attrs={Primitive=1},all_prim_fn_vars=[@test_fused_add]})
+    #   }
+    #   def @test_fused_add = <lowered PrimFunc>
+
+    main = actual_mod["main"]
+    call = main.body
+    assert call.op.name == "call_lowered"
+    assert len(call.args) == 2
+    assert call.args[0].name_hint == "test_fused_add"
+    assert len(call.args[1].fields) == 2
+    assert call.args[1].fields[0].name_hint == "a"
+    assert call.args[1].fields[1].name_hint == "a"
+    assert call.attrs.metadata["relay_attrs"].Primitive == 1
+    assert len(call.attrs.metadata["all_prim_fn_vars"]) == 1
+    assert call.attrs.metadata["all_prim_fn_vars"][0].name_hint == 
"test_fused_add"
+
+    test_fused_add = actual_mod["test_fused_add"]
+    assert isinstance(test_fused_add, tvm.tir.PrimFunc)
+
+
+def test_lower_compiler():
+    @tvm._ffi.register_func("relay.ext.test_pass_lower_te")
+    def relay_ext_test_pass_lower_te(func):
+        return None
+
+    input_mod = tvm.parser.parse(
+        """
+        #[version = "0.0.5"]
+        def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] {
+          %0 = fn(%x : Tensor[(5, 7), float32], %y : Tensor[(5, 7), float32], 
Primitive=1, Compiler="test_pass_lower_te", global_symbol="test_add") -> 
Tensor[(5, 7), float32] { 
+            add(%x, %y)
+          };
+          %0(%a, %a)  
+        }      
+        """,
+        "from_string",
+        None,
+        None,
+    )
+
+    actual_mod = transform(input_mod)
+
+    # Expected:
+    #   def @main(%a : Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] {
+    #     %0 = (%a, %a)
+    #     call_lowered(@test_add, %0, metadata={relay_attrs={Primitive=1, 
Compiler="test_pass_lower_te", global_symbol="test_add"}, all_prim_fn_vars=[]})
+    #   }
+    #   def @test_add(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), 
float32], Extern=1) -> Tensor[(5, 7), float32] {
+    #     add(%x, %y)
+    #   }
+
+    main = actual_mod["main"]
+    call = main.body
+    assert call.op.name == "call_lowered"
+    assert len(call.args) == 2
+    assert call.args[0].name_hint == "test_add"
+    assert len(call.args[1].fields) == 2
+    assert call.args[1].fields[0].name_hint == "a"
+    assert call.args[1].fields[1].name_hint == "a"
+    assert call.attrs.metadata["relay_attrs"].Primitive == 1
+    assert call.attrs.metadata["relay_attrs"].Compiler == "test_pass_lower_te"
+    assert call.attrs.metadata["relay_attrs"].global_symbol == "test_add"
+    assert len(call.attrs.metadata["all_prim_fn_vars"]) == 0
+
+    test_add = actual_mod["test_add"]
+    assert isinstance(test_add, tvm.relay.Function)
+    assert test_add.attrs["Extern"] == 1
+
+
+def test_lower_extern():
+    input_mod = tvm.parser.parse(
+        """
+        #[version = "0.0.5"]
+        def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] {
+          @my_add(%a, %a)
+        }
+        def @my_add(%x : Tensor[(5, 7), float32], %y : Tensor[(5, 7), 
float32], Extern=1) -> Tensor[(5, 7), float32] { 
+          add(%x, %y)
+        }      
+        """,
+        "from_string",
+        None,
+        None,
+    )
+
+    actual_mod = transform(input_mod)
+
+    # Expected:
+    #   def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] {
+    #     %0 = (%a, %a);
+    #     call_lowered(@my_add, %0, metadata={relay_attrs={Extern=1}, 
all_prim_fn_vars=[]})
+    #   }
+    #   def @my_add(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], 
Extern=1) -> Tensor[(5, 7), float32] {
+    #     add(%x, %y)
+    #   }
+
+    main = actual_mod["main"]
+    call = main.body
+    assert call.op.name == "call_lowered"
+    assert len(call.args) == 2
+    assert call.args[0].name_hint == "my_add"
+    assert len(call.args[1].fields) == 2
+    assert call.args[1].fields[0].name_hint == "a"
+    assert call.args[1].fields[1].name_hint == "a"
+    assert call.attrs.metadata["relay_attrs"].Extern == 1
+    assert len(call.attrs.metadata["all_prim_fn_vars"]) == 0
+
+    test_add = actual_mod["my_add"]
+    assert isinstance(test_add, tvm.relay.Function)
+    assert test_add.attrs["Extern"] == 1
+
+
+def test_lower_extern_with_dynamic_shape():
+    input_mod = tvm.parser.parse(
+        """
+        #[version = "0.0.5"]
+        def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(?, ?), float32] {
+          @my_dyn(%a, %a)
+        }
+        def @my_dyn(%x : Tensor[(5, 7), float32], %y : Tensor[(5, 7), 
float32], Extern=1) -> Tensor[(?, ?), float32] { 
+          add(%x, %y)
+        }      
+        """,
+        "from_string",
+        None,
+        None,
+    )
+
+    actual_mod = transform(input_mod)
+
+    # Expected:
+    # def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(?, ?), float32] {
+    #   %0 = (%a, %a);
+    #   call_lowered(@my_dyn, %0, 
metadata={prim_shape_fn_var='shape_func_add', relay_attrs={Extern=1}, 
prim_shape_fn_states=[2, 2], prim_shape_fn_num_inputs=2, 
all_prim_shape_fn_vars=['shape_func_add'], prim_shape_fn_num_outputs=1, 
all_prim_fn_vars=[]})
+    # }
+    # def @my_dyn(%x: Tensor[(5, 7), float32] , %y: Tensor[(5, 7), float32] , 
Extern=1) -> Tensor[(?, ?), float32] {
+    #   add(%x, %y)
+    # }
+    # def @shape_func_add = <shape PrimFunc>
+
+    main = actual_mod["main"]
+    call = main.body
+    assert call.op.name == "call_lowered"
+    assert len(call.args) == 2
+    assert call.args[0].name_hint == "my_dyn"
+    assert len(call.args[1].fields) == 2
+    assert call.args[1].fields[0].name_hint == "a"
+    assert call.args[1].fields[1].name_hint == "a"
+    assert call.attrs.metadata["prim_shape_fn_var"].name_hint == 
"shape_func_add"
+    assert call.attrs.metadata["relay_attrs"].Extern == 1
+    assert len(call.attrs.metadata["prim_shape_fn_states"]) == 2
+    assert call.attrs.metadata["prim_shape_fn_states"][0] == 2
+    assert call.attrs.metadata["prim_shape_fn_states"][1] == 2
+    assert call.attrs.metadata["prim_shape_fn_num_inputs"] == 2
+    assert len(call.attrs.metadata["all_prim_shape_fn_vars"]) == 1
+    assert call.attrs.metadata["all_prim_shape_fn_vars"][0].name_hint == 
"shape_func_add"
+    assert call.attrs.metadata["prim_shape_fn_num_outputs"] == 1
+    assert len(call.attrs.metadata["all_prim_fn_vars"]) == 0
+
+    my_dyn = actual_mod["my_dyn"]
+    assert isinstance(my_dyn, tvm.relay.Function)
+    assert my_dyn.attrs["Extern"] == 1
+
+    shape_func_add = actual_mod["shape_func_add"]
+    assert isinstance(shape_func_add, tvm.tir.PrimFunc)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/relay/transform/test_compiler_function_utils.py 
b/tests/python/relay/transform/test_compiler_function_utils.py
index 13e0f98e79..b9eb115475 100644
--- a/tests/python/relay/transform/test_compiler_function_utils.py
+++ b/tests/python/relay/transform/test_compiler_function_utils.py
@@ -38,8 +38,7 @@ metatable = {
             (2304,),  # 1
             (600, 32, 64),  # 2
         ],
-    ),
-    "attributes": [{"relay_attrs": None}],
+    )
 }
 
 
@@ -115,7 +114,7 @@ def expected_extern_mod():
         """
         #[version = "0.0.5"]
         def @main(%x0 : Tensor[(1600, 768), float16], %x3 : Tensor[(600, 32, 
64), float16]) -> (Tensor[(1600, 2304), float16], Tensor[(600, 32, 32), 
float16]) {
-          %1 = call_lowered(@tvmgen_default_cutlass_main_0, (%x0, 
meta[relay.Constant][0], meta[relay.Constant][1]), 
metadata=meta[attributes][0]);
+          %1 = @tvmgen_default_cutlass_main_0(%x0, meta[relay.Constant][0], 
meta[relay.Constant][1]);
           %2 = fn(%y_3_i0: Tensor[(600, 32, 64), float16], %y_3_i1: 
Tensor[(600, 32, 64), float16],
                   Inline=1, Compiler="cublas", 
global_symbol="tvmgen_default_cublas_main_3", Primitive=1) -> Tensor[(600, 32, 
32), float16] {
             %6 = fn (%FunctionVar_0_01: Tensor[(600, 32, 64), float16], 
%FunctionVar_0_11: Tensor[(600, 32, 64), float16],

Reply via email to