This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ec24ae60a0 [BYOC] RelayToTIR custom codegen passes can still depend on
dynamic shape functions (#11619)
ec24ae60a0 is described below
commit ec24ae60a028f5aae0fa2f1d8a668eb6bf366414
Author: Mark Shields <[email protected]>
AuthorDate: Thu Jun 9 21:22:49 2022 -0700
[BYOC] RelayToTIR custom codegen passes can still depend on dynamic shape
functions (#11619)
In #11474 I got ready to switch CUTLASS from function-at-a-time to
IRModule-at-a-time compilation.
However my approach didn't handle dynamic shape functions, so I adjust it
here.
The idea is still that such passes will leave behind
calls to 'extern' functions. However, converting those
calls to 'call_lowered' form in
MarkCompilerFunctionsAsExtern is too soon since only
the TECompiler knows how to capture all the attributes
necessary to support dynamic shape functions.
So stop doing that in MarkCompilerFunctionsAsExtern and
instead support this case properly in the TECompiler.
While there try to chip away at the chronic lack of structure in
te_compiler.cc. Every little bit helps.
Add a basic unit test.
---
src/relay/backend/aot_executor_codegen.cc | 8 +-
src/relay/backend/graph_executor_codegen.cc | 27 +-
src/relay/backend/interpreter.cc | 3 +-
src/relay/backend/te_compiler.cc | 329 ++++++++++++++-------
src/relay/backend/te_compiler.h | 32 +-
src/relay/backend/vm/compiler.cc | 24 +-
src/relay/transforms/compiler_function_utils.cc | 51 ----
src/relay/transforms/compiler_function_utils.h | 11 +-
tests/python/relay/backend/test_pass_lower_te.py | 241 +++++++++++++++
.../transform/test_compiler_function_utils.py | 5 +-
10 files changed, 503 insertions(+), 228 deletions(-)
diff --git a/src/relay/backend/aot_executor_codegen.cc
b/src/relay/backend/aot_executor_codegen.cc
index 167afd2c5f..381cfa0c9d 100644
--- a/src/relay/backend/aot_executor_codegen.cc
+++ b/src/relay/backend/aot_executor_codegen.cc
@@ -1064,9 +1064,8 @@ class AOTExecutorCodegen : public MixedModeVisitor {
mod = transform::ToANormalForm()(mod);
- IRModule lowered_mod = tec::LowerTEPass(
- mod_name,
- [this, workspace_byte_alignment](BaseFunc func) {
+ IRModule lowered_mod =
+ tec::LowerTE(mod_name, config_, [this,
workspace_byte_alignment](BaseFunc func) {
// We need to maintain the constant map for external
// functions so we pass this processing function which
// allows us to process each function as we lower it.
@@ -1078,8 +1077,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
// execute as a further pass, instead writing data to the
// lowering process directly.
tec::UpdateFunctionMetadata(func, this->function_metadata_,
workspace_byte_alignment);
- },
- config_)(mod);
+ })(mod);
auto lowered_main = lowered_mod->Lookup("main");
auto lowered_main_func = GetRef<Function>(lowered_main.as<FunctionNode>());
diff --git a/src/relay/backend/graph_executor_codegen.cc
b/src/relay/backend/graph_executor_codegen.cc
index 7dba23803f..af426e5c71 100644
--- a/src/relay/backend/graph_executor_codegen.cc
+++ b/src/relay/backend/graph_executor_codegen.cc
@@ -217,22 +217,19 @@ class GraphExecutorCodegen : public
backend::MemoizedExprTranslator<std::vector<
mod = WithAttr(mod, "main_func_info", func_info);
}
- IRModule lowered_mod = tec::LowerTEPass(
- mod_name_,
- [this](BaseFunc func) {
- // We need to maintain the constant map for external
- // functions so we pass this processing function which
- // allows us to process each function as we lower it.
- if (func->GetAttr<String>(attr::kCompiler).defined()) {
- UpdateConstants(func, ¶ms_);
- }
+ IRModule lowered_mod = tec::LowerTE(mod_name_, config_, [this](BaseFunc
func) {
+ // We need to maintain the constant map for external
+ // functions so we pass this processing function which
+ // allows us to process each function as we lower it.
+ if (func->GetAttr<String>(attr::kCompiler).defined()) {
+ UpdateConstants(func, ¶ms_);
+ }
- // TODO(@areusch, @jroesch): We should refactor this to
- // execute as a further pass, instead writing data to the
- // lowering process directly.
- tec::UpdateFunctionMetadata(func, this->function_metadata_);
- },
- config_)(mod);
+ // TODO(@areusch, @jroesch): We should refactor this to
+ // execute as a further pass, instead writing data to the
+ // lowering process directly.
+ tec::UpdateFunctionMetadata(func, this->function_metadata_);
+ })(mod);
Optional<backend::FunctionInfo> main_func_info =
lowered_mod->GetAttr<backend::FunctionInfo>("main_func_info");
diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc
index 9661040eab..65a0fdc948 100644
--- a/src/relay/backend/interpreter.cc
+++ b/src/relay/backend/interpreter.cc
@@ -960,8 +960,7 @@ IRModule Prepare(IRModule mod, const CompilationConfig&
config) {
// eta expand to support constructors in argument position.
transform::EtaExpand(
/*expand_constructor=*/true, /*expand_global_var=*/false),
- transform::InferType(),
- tec::LowerTEPass(/*module_name=*/"intrp", [](BaseFunc func) { /* no-op
*/ }, config)});
+ transform::InferType(), tec::LowerTE(/*module_name=*/"intrp", config)});
transform::PassContext pass_ctx = transform::PassContext::Current();
With<transform::PassContext> ctx(pass_ctx);
diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc
index c78f3abd6e..e9491b0a89 100644
--- a/src/relay/backend/te_compiler.cc
+++ b/src/relay/backend/te_compiler.cc
@@ -17,6 +17,76 @@
* under the License.
*/
+/*!
+ * \file relay/backend/te_compiler.cc
+ * \brief Manages the transition from Relay "Primitive" \p Functions to TIR \p
PrimFuncs. Also
+ * handles invocation of external codegen.
+ *
+ * \p LowerTEPass handles the following (as a monolithic blob of code):
+ *
+ * - Most importantly, any function with the "Primitive" attribute is first
converted to TE by
+ * \p LowerToTECompute (see te_compiler_cache.cc) using each operator's
'compute' function.
+ * The TE is then 'scheduled' to TIR using the 'anchor' operator's
'schedule' function. Both
+ * of those functions come from the \p OpStrategy returned by the Python
+ * 'relay.backend.lower_call' function (see te_compiler.py).
+ * The TIR is packed as a \p PrimFunc and introduced as a new global
function. Calls to the
+ * original "Primitive" function are then rewritten to the form:
+ * \code
+ * call_lowered(@new_global, (... original args...), attributes)
+ * \endcode
+ *
+ * - The above "Primitive" function can appear:
+ * - As a global function
+ * - As a let-bound function
+ * - As an inline function, ie the 'op' of calls.
+ * In all three cases it is possible for the same "Primitive" function to
be called multiple
+ * times, and that sharing must be respected.
+ *
+ * - "Primitive" functions must have a "global_symbol" attribute matching
their desired or
+ * existing global name. Care is taken to ensure GlobalVars with the same
name are shared.
+ *
+ * - It is possible for multiple structurally equal "Primitive" functions to
appear in the same
+ * \p IRModule. Only one implementation should be generated, and all calls
should share that
+ * implementation.
+ *
+ * - When later converting to DPS (see memory_alloc.cc) we must handle
functions whose result
+ * tensor shapes depend at runtime on the input tensor shapes and/or data.
+ * - That dependency is first described in TE form (see \p MakeShapeFunc in
+ * te_compiler_cache.cc), then scheduled to yield a 'dynamic shape
function' \p PrimFunc.
+ * This relies on each operator's "FShapeFunc" and "TShapeDataDependent"
attributes.
+ * Since shapes are rank-1 tensors everything can be reflected back down
into the regular
+ * TE/TIR forms.
+ * - Then the call_lowered attributes must record everything about the
dynamic shape function
+ * later needed by memory_alloc.cc. We call this 'cross linking' the
call with the shape
+ * function.
+ *
+ * - Two external codegen mechanisms are supported, both triggered by
"Primitive" functions which
+ * also have a "Compiler" attribute bound to $compiler:
+ * - Function-at-a-time (old style): The primitive function is passed to
the function
+ * registered as 'relay.ext.$compiler'. The function returns a
runtime::Module which
+ * should return true for \p ImplementsFunction for the function's
global name. That
+ * module is added to the IRModule's "external_mods" attributes.
+ *   - IRModule-at-a-time (new style): The \p RelayToTIRTargetHook sub-pass
looks for
+ * $compiler names which correspond to TargetKind names with a \p
RelayToTIR attribute.
+ * The \p Pass bound to that attribute is run, and each such 'custom'
pass can do what
+ * it likes, including replacing Functions with PrimFuncs, or adding new
runtime::Modules
+ * to the IRModule's "external_mods" attribute.
+ *
+ * - Calls to functions added by external codegen are also rewritten to
call_lowered form, and
+ * may also require cross-linking to dynamic shape functions. However,
since the functions
+ * are/will be implemented by a runtime::Module all the Relay type
information is no longer
+ * available. So the Relay definitions for these "Primitive" "Compiler"
functions are retained
+ * in the \p IRModule, but marked with the "Extern" attribute to signal the
function is now
+ * just for carrying metadata.
+ *
+ * - Some operators are handled specially:
+ * - 'reshape', since it's a no-op on the underlying tensor buffer, and
this is handled by
+ * condition tests in many passes.
+ * - 'debug', since it's intercepted differently depending on runtimes.
+ *
+ * TODO(mbs): This desperately deserves a refactor to separate all these
concerns. See Relax.
+ */
+
#include "./te_compiler.h"
#include <tvm/driver/driver_api.h>
@@ -222,7 +292,7 @@ class TECompilerImpl : public TECompilerNode {
} else {
// It is valid for the external codegen function to return null:
// - Unit tests can use it.
- // - The true compilation may have already been handled by a
RelayToTIR custom hook pass
+ // - The true compilation may have already been handled by a
RelayToTIR custom pass
// on the Target's kind. The original Relay functions will be
left in place so
// that we can capture that their function names are now
externally defined.
VLOG(1) << "Note that no external runtime module was generated by
external codegen '"
@@ -566,100 +636,128 @@ class LowerTensorExprMutator : public
DeviceAwareExprMutator {
return itr->second;
}
} else if (const auto* function_node = expr.as<FunctionNode>()) {
- if (!function_node->HasNonzeroAttr(attr::kPrimitive)) {
- // Not marked as primitive by FuseOps.
- return {};
- }
- if (const auto* call_node = function_node->body.as<CallNode>()) {
- if (call_node->op == debug_op_) {
- // Debug 'primitives' are not lowered.
- return {};
+ if (function_node->HasNonzeroAttr(attr::kExtern)) {
+ // We have a regular call to an 'extern' function. The call itself
needs to be rewritten
+ // to call_lowered form, and any required dynamic shape functions
generated and
+ // cross-linked.
+ return GetRef<Function>(function_node);
+ } else if (function_node->HasNonzeroAttr(attr::kPrimitive)) {
+ if (const auto* call_node = function_node->body.as<CallNode>()) {
+ if (call_node->op == debug_op_) {
+ // Debug 'primitives' are not lowered.
+ return {};
+ }
}
+ // We have a regular call to a 'primitive' function (possibly with a
'Compiler' attribute).
+ // We need to lower and rewrite the call.
+ return GetRef<Function>(function_node);
+ } else {
+ // Not marked as primitive during partitioning or TVM fusion.
+ return {};
}
- return GetRef<Function>(function_node);
} else {
return {};
}
}
/*!
- * \brief Lowers the primitive function \p func to TIR for ultimate execution
- * on a device with configuration \p target. Returns the global var bound
- * to the TIR implementation, and attributes to attach to the call to
identify it as
- * a TIR call.
+ * \brief Returns a 'call_lowered' call to \p prim_fn_var with \p args and
\p span with all the
+ * required attributes filled in. Generally \p prim_fn_var will correspond
to the lowered or
+ * externally codegen-ed form of \p original_function, where \p
lowered_functions binds all
+ * the required lowered functions.
+ *
+ * The call's attributes will capture:
+ * - Any attributes on the original_function.
+ * - All the lowered functions.
+ * TODO(mbs): Pretty sure that's no longer needed.
+ *  - Details needed to cross-link the call to its dynamic shape function,
if any.
*/
- Expr MakeLoweredCall(Function func, Array<Expr> visited_args, Span span,
Target target) {
- CCacheKey key = CCacheKey(func, target);
- CachedFunc cfunc = compiler_->Lower(key, module_name_);
- ICHECK(cfunc.defined());
-
- auto opt_compiler = func->GetAttr<String>(attr::kCompiler);
+ Expr MakeLoweredCall(const BaseFunc& original_function, const GlobalVar&
prim_fn_var,
+ Array<Expr> args, Span span, const Target& target,
+ const Map<GlobalVar, BaseFunc>& lowered_functions) {
+ auto opt_compiler = original_function->GetAttr<String>(attr::kCompiler);
// Add some metadata on top of the *original function* and invoke the
callback so it can
// be captured.
// TODO(@areusch, @jroesch): this metadata is for AOT, this should be our
interface for AOT
Map<GlobalVar, tir::PrimFunc> prim_fns;
Array<GlobalVar> all_prim_fn_vars;
- for (const auto& kv : cfunc->funcs->functions) {
+ for (const auto& kv : lowered_functions) {
if (opt_compiler) {
- // We expect just the original func but with just the ExternalSymbol
attribute signaling
- // the function (will be) compiled externally.
+ // We expect the original function to have just the "Extern" attribute
signaling the
+ // function (will be) compiled externally.
ICHECK(kv.second.as<FunctionNode>())
<< PrettyPrint(kv.first) << " must be bound to an (external)
Function";
} else {
- // We expect one or more PrimFuncs, one of which corresponds to 'the'
lowered primitive
- // (and the rest in support of that via tir::Calls).
+ // We expect one or more PrimFuncs, one of which corresponds to 'the'
lowered primitive,
+ // and the rest are in support of that via tir::Calls.
ICHECK(kv.second.as<tir::PrimFuncNode>())
<< PrettyPrint(kv.first) << " must be bound to a PrimFunc";
prim_fns.Set(kv.first, Downcast<tir::PrimFunc>(kv.second));
all_prim_fn_vars.push_back(kv.first);
}
}
- Function func_with_metadata = func;
- func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var",
cfunc->prim_fn_var);
- func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns);
- func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget,
cfunc->target);
- this->process_fn_(func_with_metadata);
+ // Alas, WithAttr cannot work with base classes.
+ if (const auto* prim_func_node = original_function.as<te::PrimFuncNode>())
{
+ auto func_with_metadata = GetRef<te::PrimFunc>(prim_func_node);
+ func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var",
prim_fn_var);
+ func_with_metadata = WithAttr(func_with_metadata, "prim_funcs",
prim_fns);
+ func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget,
target);
+ this->process_fn_(func_with_metadata);
+ } else {
+ const auto* function_node = original_function.as<FunctionNode>();
+ ICHECK(function_node);
+ auto func_with_metadata = GetRef<Function>(function_node);
+ func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var",
prim_fn_var);
+ func_with_metadata = WithAttr(func_with_metadata, "prim_funcs",
prim_fns);
+ func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget,
target);
+ this->process_fn_(func_with_metadata);
+ }
+
+ // Now prepare the attributes of the call_lowered.
CallLoweredAttrs call_lowered_attrs;
- // Non-External Relay Function
// TODO(mbs): "reshape" cleanup.
- if (!opt_compiler && func->HasNonzeroAttr(attr::kReshapeOnly)) {
+ if (!opt_compiler &&
original_function->HasNonzeroAttr(attr::kReshapeOnly)) {
call_lowered_attrs.metadata.Set(attr::kReshapeOnly, tvm::Integer(1));
}
- call_lowered_attrs.metadata.Set("relay_attrs", func->attrs);
+ call_lowered_attrs.metadata.Set("relay_attrs", original_function->attrs);
call_lowered_attrs.metadata.Set("all_prim_fn_vars", all_prim_fn_vars);
- if (IsDynamic(func->ret_type)) {
- // Also lower the companion dynamic shape function.
- // Shape function keys use the underlying primitive function as their
'function',
- // but the generic 'cpu' target as the target since all shape functions
run
- // on the host cpu irrespective of where the primitive runs.
- CCacheKey shape_key(func, config_->host_virtual_device->target);
- CachedFunc lowered_shape_func = compiler_->LowerShapeFunc(shape_key);
-
- // Capture the shape function's global var and parameters 'states' in
call
- // annotations so calling convention can be recovered.
- // TODO(mbs): Shape cleanup.
- call_lowered_attrs.metadata.Set("prim_shape_fn_var",
lowered_shape_func->prim_fn_var);
- call_lowered_attrs.metadata.Set("prim_shape_fn_states",
-
lowered_shape_func->shape_func_param_states);
- call_lowered_attrs.metadata.Set("prim_shape_fn_num_inputs",
-
Integer(static_cast<int>(lowered_shape_func->inputs.size())));
- call_lowered_attrs.metadata.Set(
- "prim_shape_fn_num_outputs",
- Integer(static_cast<int>(lowered_shape_func->outputs.size())));
- Array<GlobalVar> all_prim_shape_fn_vars;
- for (const auto& kv : lowered_shape_func->funcs->functions) {
- CHECK(kv.second.as<tir::PrimFuncNode>()) << "must be a prim fn";
- all_prim_shape_fn_vars.push_back(kv.first);
+ if (const auto* function_node = original_function.as<FunctionNode>()) {
+ if (IsDynamic(function_node->ret_type)) {
+ // Create a dynamic shape function to calculate the expected shape of
the results of
+ // the lowered function.
+ // Shape function keys use the original function as their 'function',
but the generic 'cpu'
+ // target as the target since all shape functions run on the host cpu
irrespective of where
+ // the primitive runs.
+ CCacheKey shape_key(GetRef<Function>(function_node),
config_->host_virtual_device->target);
+ CachedFunc lowered_shape_func = compiler_->LowerShapeFunc(shape_key);
+
+ // Capture the shape function's global var and parameters 'states' in
call
+ // annotations so calling convention can be recovered.
+ // TODO(mbs): Shape cleanup.
+ call_lowered_attrs.metadata.Set("prim_shape_fn_var",
lowered_shape_func->prim_fn_var);
+ call_lowered_attrs.metadata.Set("prim_shape_fn_states",
+
lowered_shape_func->shape_func_param_states);
+ call_lowered_attrs.metadata.Set(
+ "prim_shape_fn_num_inputs",
+ Integer(static_cast<int>(lowered_shape_func->inputs.size())));
+ call_lowered_attrs.metadata.Set(
+ "prim_shape_fn_num_outputs",
+ Integer(static_cast<int>(lowered_shape_func->outputs.size())));
+ Array<GlobalVar> all_prim_shape_fn_vars;
+ for (const auto& kv : lowered_shape_func->funcs->functions) {
+ CHECK(kv.second.as<tir::PrimFuncNode>()) << "must be a prim fn";
+ all_prim_shape_fn_vars.push_back(kv.first);
+ }
+ call_lowered_attrs.metadata.Set("all_prim_shape_fn_vars",
all_prim_shape_fn_vars);
}
- call_lowered_attrs.metadata.Set("all_prim_shape_fn_vars",
all_prim_shape_fn_vars);
}
- return CallLowered(cfunc->prim_fn_var, std::move(visited_args),
std::move(call_lowered_attrs),
+ return CallLowered(prim_fn_var, std::move(args),
std::move(call_lowered_attrs),
std::move(span));
}
@@ -697,43 +795,51 @@ class LowerTensorExprMutator : public
DeviceAwareExprMutator {
}
Expr DeviceAwareVisitExpr_(const CallNode* call_node) override {
- // We can see five forms of calls:
- // 1. A 'normal' Relay call to a Function with the "primitive" attribute.
We will need
- // to lower that to a global PrimFunc and rewrite the call to:
+ // We can see six forms of calls:
+ // 1. A 'normal' Relay call to a Function with the "Primitive" attribute
and not "Compiler"
+ // attribute. We will need to lower that to a global PrimFunc and
rewrite the call to:
// call_lowered(@new_global, (arg1, ..., argn), <attributes>)
- // However there are a few special forms which are excluded from this
treatment, see
- // below.
- // 2. A 'normal' Relay call to a Function with the "compiler" attribute.
We will need
- // to invoke the appropriate BYOC toolchain function to yield a
runtime module and
- // rewrite the call to the same form as above.
- // 3. A 'normal' Relay call to a PrimFunc which has already been supplied
via a global
- // definition. We rewrite to use the call_lowered form, but otherwise
nothing else
+ // If needed, the call needs to be cross-linked with any dynamic shape
functions.
+ // (However, some primitives are special and handled separately.)
+ // 2. A 'normal' Relay call to a Function with the "Primitive" and
"Compiler" attributes. We
+ // will need to invoke the "relay.ext.<compiler>" function to yield a
runtime module, and
+ // rewrite the call to the same form as above. Dynamic shape function
cross-linking may
+ // also be needed.
+ // 3. A 'normal' Relay call to a Function with the "Extern" attribute.
This function has
+ // already been compiled by an external codegen and a definition for
it exists in some
+ // runtime module. Again, we rewrite to call_lowered form, and
cross-link with a dynamic
+ // shape function if needed.
+ // 4. A 'normal' Relay call to a PrimFunc which has already been supplied
via a global
+ // definition. We rewrite those to use the call_lowered form, but
otherwise nothing else
// needs to be done.
- // 4. A 'normal' Relay call to a Relay Function without any special
attribute. These
+ // 5. A 'call_lowered' call from an earlier invocation of this pass or
otherwise deliberately
+ // inserted. It has all the required attributes, and any associated
dynamic shape function
+ // has been generated and cross-linked. These calls are not changed.
+ // 6. A 'normal' Relay call to a Relay Function without any special
attribute. These
// calls are not changed.
- // 5. A call_lowered call from an earlier invocation of this pass.
- // Note that ResolveToPrimitive will yield non-null only for cases 1-3.
+ //
+ // Note that ResolveToPrimitive will yield non-null only for cases 1-4.
+
+ // Prepare the arguments and op.
+ Array<Expr> new_args;
+ for (const auto& arg : call_node->args) {
+ new_args.push_back(VisitExpr(arg));
+ }
+ Expr new_op = VisitExpr(call_node->op);
// Look for (possibly indirect) calls to primitives.
BaseFunc primitive_func = ResolveToPrimitive(call_node->op);
if (!primitive_func.defined()) {
- // Not a call to a primitive function we need to rewrite.
+ // Cases 5 and 6: Leave as ordinary call.
if (const auto* function_node = call_node->op.as<FunctionNode>()) {
process_fn_(GetRef<Function>(function_node));
}
- return DeviceAwareExprMutator::DeviceAwareVisitExpr_(call_node);
- }
-
- // Prepare the arguments.
- Array<Expr> new_args;
- for (const auto& arg : call_node->args) {
- new_args.push_back(VisitExpr(arg));
+ return WithFields(GetRef<Call>(call_node), std::move(new_op),
std::move(new_args));
}
- // Special case: device_copies are left as calls to primitive operators
- // (thus undoing FuseOps) so that each backend can handle them directly.
- // TODO(mbs): device_copy cleanup. Would be better for FuseOps to just
leave device_copy
- // alone.
+ // Special case for case 1: device_copies are left as calls to primitive
operators
+ // so that each backend can handle them directly.
+ // TODO(mbs): device_copy cleanup. Would be better for FuseOps to just
leave device_copy alone.
if (const auto* function_node = primitive_func.as<FunctionNode>()) {
DeviceCopyProps device_copy_props =
GetDeviceCopyProps(function_node->body);
if (device_copy_props.body.defined()) {
@@ -743,33 +849,23 @@ class LowerTensorExprMutator : public
DeviceAwareExprMutator {
}
}
- // Special case: If already lowered by other means then so we don't need
to mutate
- // the call but we do need to mutate the arguments
+ ICHECK(call_node->type_args.empty()) << "lowered functions cannot be
polymorphic";
+
+ // Case 4: If the function has already been lowered we just need to update
the call.
if (const auto* prim_func_node = primitive_func.as<tir::PrimFuncNode>()) {
// Function should already be Target annotated by this point
// but the TE Compiler metadata is still needed for the callback
// TODO(Mousius) - Robustify this to not assume we're in the GlobalVar
for Target Hooks
- GlobalVar prim_func_var = Downcast<GlobalVar>(call_node->op);
+ Optional<Target> opt_target =
primitive_func->GetAttr<Target>(tvm::attr::kTarget);
+ ICHECK(opt_target.defined());
+ auto prim_fn_var = Downcast<GlobalVar>(call_node->op);
tir::PrimFunc prim_func = GetRef<tir::PrimFunc>(prim_func_node);
-
- Map<GlobalVar, tir::PrimFunc> prim_fns = {{prim_func_var, prim_func}};
- tir::PrimFunc func_with_metadata = WithAttrs(prim_func, {
-
{"prim_fn_var", prim_func_var},
-
{"prim_funcs", prim_fns},
- });
-
- ICHECK(!IsDynamic(call_node->checked_type()));
- CallLoweredAttrs call_lowered_attrs;
- call_lowered_attrs.metadata.Set("relay_attrs", primitive_func->attrs);
-
- process_fn_(func_with_metadata);
- ICHECK(call_node->type_args.empty()) << "lowered functions cannot be
polymorphic";
- return CallLowered(prim_func_var, std::move(new_args),
std::move(call_lowered_attrs),
- call_node->span);
+ Map<GlobalVar, BaseFunc> prim_fns = {{prim_fn_var, prim_func}};
+ return MakeLoweredCall(primitive_func, prim_fn_var, std::move(new_args),
call_node->span,
+ opt_target.value(), prim_fns);
}
- // Typical case: call to fused primitive Relay Function.
- // Find the desired target device.
+ // Determine the target for lowering or external codegen.
Target target;
Optional<String> opt_compiler =
primitive_func->GetAttr<String>(attr::kCompiler);
if (opt_compiler.defined()) {
@@ -791,10 +887,20 @@ class LowerTensorExprMutator : public
DeviceAwareExprMutator {
ICHECK(target.defined());
}
- // Lower the primitive function for that target.
- Function function = Downcast<Function>(primitive_func);
- ICHECK(call_node->type_args.empty()) << "lowered functions cannot be
polymorphic";
- return MakeLoweredCall(function, std::move(new_args), call_node->span,
target);
+ if (primitive_func->HasNonzeroAttr(attr::kExtern)) {
+ // Case 3: Function has already been compiled.
+ GlobalVar prim_fn_var = Downcast<GlobalVar>(call_node->op);
+ return MakeLoweredCall(primitive_func, prim_fn_var, std::move(new_args),
call_node->span,
+ target, /*lowered_functions=*/{});
+ } else {
+ // Cases 1 and 2: lower the primitive function for the desired target,
possibly using external
+ // codegen.
+ CCacheKey key(Downcast<Function>(primitive_func), target);
+ CachedFunc cfunc = compiler_->Lower(key, module_name_);
+ ICHECK(cfunc.defined());
+ return MakeLoweredCall(primitive_func, cfunc->prim_fn_var,
std::move(new_args),
+ call_node->span, target, cfunc->funcs->functions);
+ }
}
IRModule module_;
@@ -1046,6 +1152,7 @@ void UpdateFunctionMetadata(BaseFunc func,
function_metadata.Set(prim_fn_var.value()->name_hint, fi);
}
+/*! \brief Main lowering driver. */
IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn
process_fn,
CompilationConfig config) {
TECompiler compiler(module);
@@ -1163,7 +1270,7 @@ Map<Target, IRModule> GetPerTargetModules(IRModule mod) {
return per_target_modules;
}
-Pass LowerTEPass(String module_name, ProcessFn process_fn, CompilationConfig
complilation_config) {
+Pass LowerTE(String module_name, CompilationConfig complilation_config,
ProcessFn process_fn) {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
[=](IRModule module,
PassContext ctx) {
return LowerTE(module, module_name, process_fn, complilation_config);
@@ -1174,6 +1281,12 @@ Pass LowerTEPass(String module_name, ProcessFn
process_fn, CompilationConfig com
tvm::transform::CreateModulePass(pass_func, 0, "LowerTE",
{"InferType"}), InferType(),
tvm::tir::transform::ExtractPrimFuncConstants()});
}
+
+TVM_REGISTER_GLOBAL("relay.tec.LowerTE")
+ .set_body_typed([](String module_name, CompilationConfig
compilation_config) {
+ return LowerTE(std::move(module_name), std::move(compilation_config));
+ });
+
} // namespace tec
} // namespace relay
} // namespace tvm
diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h
index 8312a20cb8..5d16da4b8b 100644
--- a/src/relay/backend/te_compiler.h
+++ b/src/relay/backend/te_compiler.h
@@ -18,8 +18,8 @@
*/
/*!
- * \file relay/backend/tir_compiler.h
- * * \brief Internal compilation layer which lowers Relay "primitive
functions" to TIR PrimFns.
+ * \file relay/backend/te_compiler.h
+ * \brief Internal compilation layer which lowers Relay "primitive functions"
to TIR PrimFns.
*
*
* This represents the new design of the Relay compilation flow and will
replace the interface
@@ -173,36 +173,22 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const
IRModule& mod, const Compila
*/
Map<Target, IRModule> GetPerTargetModules(IRModule mod);
-/*! \brief Lower an IRModule's primitive functions to TIR.
- *
- * This is the "back half" of the Relay compiler which lowers "primitive
functions"
- * to TE expressions, schedules them, and then to TIR.
- *
- * \param module The IRModule.
- * \param memory_plan The memory plan used during lowering
- * \param module_name The name of this module
- * \param process_fn Callback allowing one-level up code generators to process
- * each function that we lower
- * \return The lowered module, see above.
- */
-IRModule LowerTE(
- const IRModule& module, backend::StaticMemoryPlan memory_plan, const
String& module_name,
- ProcessFn process_fn = [](BaseFunc f) {});
+inline void DefaultProcessFn(BaseFunc) {}
/*!
* \brief Pass to lower an IRModule's primitive functions to TIR.
*
* This is the "back half" of the Relay compiler which lowers "primitive
functions"
- * to TE expressions, schedules them, and then to TIR. It annotates all
functions
- * with their target.
+ * to TE expressions, schedules them, and emits PrimFuncs.
*
- * \param module_name The name of this module
- * \param process_fn Callback allowing one-level up code generators to process
- * each function that we lower
+ * \param module_name The name of this module, used as a prefix for generated
globals.
* \param config All available targets.
+ * \param process_fn Callback allowing one-level up code generators to process
+ * each function that we lower (default is no-op).
* \returns The pass which lowers primitive functions to TIR
*/
-transform::Pass LowerTEPass(String module_name, ProcessFn process_fn,
CompilationConfig config);
+transform::Pass LowerTE(String module_name, CompilationConfig config,
+ ProcessFn process_fn = DefaultProcessFn);
} // namespace tec
} // namespace relay
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index 48f12ea8aa..8820a403bf 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -1040,13 +1040,11 @@ transform::Sequential
VMCompiler::FuseAndLowerOperators(const CompilationConfig&
// Give each "primitive" Function a hash.
pass_seqs.push_back(LabelOps());
// Lower "primitive" Functions to PrimFuncs and rewrite calls.
- pass_seqs.push_back(tec::LowerTEPass(/*module_name=*/"vm_mod",
- [this](const BaseFunc& func) {
- if
(func->GetAttr<String>(attr::kCompiler).defined()) {
- backend::UpdateConstants(func,
¶ms_);
- }
- },
- config));
+ pass_seqs.push_back(tec::LowerTE(/*module_name=*/"vm_mod", config,
[this](const BaseFunc& func) {
+ if (func->GetAttr<String>(attr::kCompiler).defined()) {
+ backend::UpdateConstants(func, ¶ms_);
+ }
+ }));
// Since lowered functions are bound in the IRModule, we can now eliminate
any unused
// let-bound functions.
pass_seqs.push_back(DeadCodeElimination(/*inline_once=*/false));
@@ -1091,13 +1089,11 @@ IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) {
pass_seqs.push_back(transform::LabelOps());
// Lower all functions annotated as "primitive" by FuseOps.
- pass_seqs.push_back(tec::LowerTEPass(/*module_name=*/"vm_mod",
- [this](const BaseFunc& func) {
- if
(func->GetAttr<String>(attr::kCompiler).defined()) {
- backend::UpdateConstants(func,
&params_);
- }
- },
- config_));
+ pass_seqs.push_back(tec::LowerTE(/*module_name=*/"vm_mod", config_,
[this](const BaseFunc& func) {
+ if (func->GetAttr<String>(attr::kCompiler).defined()) {
+            backend::UpdateConstants(func, &params_);
+ }
+ }));
// Since lowered functions are bound in the IRModule, we can now eliminate
any unused
// let-bound functions.
diff --git a/src/relay/transforms/compiler_function_utils.cc
b/src/relay/transforms/compiler_function_utils.cc
index f22e9bd80d..3df07e4c57 100644
--- a/src/relay/transforms/compiler_function_utils.cc
+++ b/src/relay/transforms/compiler_function_utils.cc
@@ -81,42 +81,6 @@ class Outliner : public MixedModeMutator {
IRModule mod_;
};
-/*!
- * \brief Rewrite calls to global "Compiler" functions to use the
'call_lowered' convention.
- */
-class CallRewriter : public MixedModeMutator {
- public:
- CallRewriter(std::string compiler_filter, IRModule mod)
- : compiler_filter_(std::move(compiler_filter)), mod_(std::move(mod)) {}
-
- Expr Rewrite_(const CallNode* pre, const Expr& post) final {
- Call new_call = Downcast<Call>(post);
- if (const auto* global_var_node = new_call->op.as<GlobalVarNode>()) {
- if (const auto* function_node =
-
mod_->Lookup(GetRef<GlobalVar>(global_var_node)).as<FunctionNode>()) {
- Optional<String> opt_compiler =
function_node->GetAttr<String>(attr::kCompiler);
- if (opt_compiler.defined() &&
- (compiler_filter_.empty() || opt_compiler.value() ==
compiler_filter_)) {
- Optional<String> opt_global_symbol =
- function_node->GetAttr<String>(tvm::attr::kGlobalSymbol);
- ICHECK(opt_global_symbol.defined());
- GlobalVar global_symbol =
mod_->GetGlobalVar(opt_global_symbol.value());
- CallLoweredAttrs attrs;
- attrs.metadata.Set("relay_attrs", new_call->attrs);
- return CallLowered(global_symbol, new_call->args, attrs,
new_call->span);
- }
- }
- }
- return post;
- }
-
- private:
- /*! \brief If non-empty, the "Compiler" attribute value to require on
functions to outline. */
- std::string compiler_filter_;
- /*! \brief Module being rewritten. */
- IRModule mod_;
-};
-
} // namespace
GlobalSymbolCache::~GlobalSymbolCache() = default;
@@ -169,20 +133,6 @@ transform::Pass MarkCompilerFunctionsAsExtern(std::string
compiler_filter) {
runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)>
pass_func =
[compiler_filter = std::move(compiler_filter)](IRModule mod,
transform::PassContext ctx) {
IRModule output_mod = mod->ShallowCopy();
-
- // First pass, rewrite the calls.
- // We have to do this before marking functions as 'extern' to know
which calls to rewrite!
- for (const auto& kv : mod->functions) {
- if (const auto* function_node =
AsOptimizableFunctionNode(kv.second)) {
- Expr new_body =
- CallRewriter(compiler_filter,
output_mod).VisitExpr(function_node->body);
- Function new_function =
- WithFields(GetRef<Function>(function_node), /*opt_params=*/{},
new_body);
- output_mod->Update(kv.first, new_function);
- }
- }
-
- // Second pass, mark functions as 'extern'.
for (const auto& kv : mod->functions) {
if (const auto* function_node = kv.second.as<FunctionNode>()) {
Optional<String> opt_compiler =
function_node->GetAttr<String>(attr::kCompiler);
@@ -197,7 +147,6 @@ transform::Pass MarkCompilerFunctionsAsExtern(std::string
compiler_filter) {
}
}
}
-
return output_mod;
};
diff --git a/src/relay/transforms/compiler_function_utils.h
b/src/relay/transforms/compiler_function_utils.h
index e4b1f05211..9d1dcd9f21 100644
--- a/src/relay/transforms/compiler_function_utils.h
+++ b/src/relay/transforms/compiler_function_utils.h
@@ -43,11 +43,8 @@
*
* - \p MarkCompilerFunctionsAsExtern will replace global functions with a
matching "Compiler"
* attribute with the same function with just an "Extern" attribute,
signalling the function
- * has been dealt with. Calls to such functions will be rewritten to use
the 'call_lowered'
- * calling convention. Can be used after lowering to cleanup the IRModule.
- *
- * Note that the above behaviour is hard coded within the TECompiler, but is
only available to
- * external codegen using the Function-at-a-time "relay.ext.toolchain"
extension point.
+ * has been dealt with. However calls to such functions will be left
unchanged. Can be used
+ * after lowering to cleanup the IRModule.
*/
#ifndef TVM_RELAY_TRANSFORMS_COMPILER_FUNCTION_UTILS_H_
@@ -118,8 +115,8 @@ transform::Pass
OutlineCompilerFunctionsWithExistingGlobalSymbols(std::string co
/*!
* \brief A pass to mark all global functions which have a "Compiler"
attribute matching
- * compiler_filter as 'extern' by replacing all attributes with a single
"Extern" attribute, and
- * rewrite all calls to such functions to use the 'call_lowered' calling
convention.
+ * compiler_filter as 'extern' by replacing all attributes with a single
"Extern" attribute.
+ * Calls to such functions are not changed.
*
* If \p compiler_filter is non-empty only functions with that as their
attribute value are
* outlined.
diff --git a/tests/python/relay/backend/test_pass_lower_te.py
b/tests/python/relay/backend/test_pass_lower_te.py
new file mode 100644
index 0000000000..310a16e269
--- /dev/null
+++ b/tests/python/relay/backend/test_pass_lower_te.py
@@ -0,0 +1,241 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Exercises the LowerTE pass.
+
+import tvm
+import tvm.testing
+import logging
+
+logging.basicConfig()
+logger = logging.getLogger("test_pass_lower_te")
+logger.setLevel(logging.INFO)
+
+# Since the TE compiler needs a good refactor it has not been exposed as a
'standard' pass
+# in relay.transform. For testing grab it directly.
+LowerTE = tvm._ffi.get_global_func("relay.tec.LowerTE")
+
+
+def transform(mod):
+ logger.info("Starting module:\n%s", mod)
+ host_target = tvm.target.Target("llvm")
+ prim_target = tvm.target.Target("llvm", host=host_target)
+ ctxt = tvm.transform.PassContext()
+ config = tvm.target.make_compilation_config(ctxt, prim_target)
+ mod = tvm.relay.transform.PlanDevices(config)(mod)
+ mod = tvm.relay.transform.InferType()(mod)
+ mod = LowerTE("test", config)(mod)
+ mod = tvm.relay.transform.InferType()(mod)
+ logger.info("After LowerTE:\n%s", mod)
+ return mod
+
+
+# All attempts to use structural equality tests against an expected IRModule
parsed from
+# Relay text were thwarted by the difficulty of setting up the expected
call_lower attributes
+# with the right GlobalVar instances. So the following asserts structural
correctness the hard way.
+
+
+def test_lower_primitive():
+ input_mod = tvm.parser.parse(
+ """
+ #[version = "0.0.5"]
+ def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] {
+ %0 = fn(%x : Tensor[(5, 7), float32], %y : Tensor[(5, 7), float32],
Primitive=1) -> Tensor[(5, 7), float32] {
+ add(%x, %y)
+ };
+ %0(%a, %a)
+ }
+ """,
+ "from_string",
+ None,
+ None,
+ )
+
+ actual_mod = transform(input_mod)
+
+ # Expected:
+ # def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] {
+ # %0 = (%a, %a);
+ # call_lowered(@test_fused_add, %0,
metadata={relay_attrs={Primitive=1},all_prim_fn_vars=[@test_fused_add]})
+ # }
+ # def @test_fused_add = <lowered PrimFunc>
+
+ main = actual_mod["main"]
+ call = main.body
+ assert call.op.name == "call_lowered"
+ assert len(call.args) == 2
+ assert call.args[0].name_hint == "test_fused_add"
+ assert len(call.args[1].fields) == 2
+ assert call.args[1].fields[0].name_hint == "a"
+ assert call.args[1].fields[1].name_hint == "a"
+ assert call.attrs.metadata["relay_attrs"].Primitive == 1
+ assert len(call.attrs.metadata["all_prim_fn_vars"]) == 1
+ assert call.attrs.metadata["all_prim_fn_vars"][0].name_hint ==
"test_fused_add"
+
+ test_fused_add = actual_mod["test_fused_add"]
+ assert isinstance(test_fused_add, tvm.tir.PrimFunc)
+
+
+def test_lower_compiler():
+ @tvm._ffi.register_func("relay.ext.test_pass_lower_te")
+ def relay_ext_test_pass_lower_te(func):
+ return None
+
+ input_mod = tvm.parser.parse(
+ """
+ #[version = "0.0.5"]
+ def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] {
+ %0 = fn(%x : Tensor[(5, 7), float32], %y : Tensor[(5, 7), float32],
Primitive=1, Compiler="test_pass_lower_te", global_symbol="test_add") ->
Tensor[(5, 7), float32] {
+ add(%x, %y)
+ };
+ %0(%a, %a)
+ }
+ """,
+ "from_string",
+ None,
+ None,
+ )
+
+ actual_mod = transform(input_mod)
+
+ # Expected:
+ # def @main(%a : Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] {
+ # %0 = (%a, %a)
+ # call_lowered(@test_add , %0, metadata={relay_attrs={Primitive=1,
Compiler="test_pass_lower_te", global_symbol="test_add"}}, all_prim_fn_vars=[]})
+ # }
+ # def @test_add(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7),
float32], Extern=1) -> Tensor[(5, 7), float32] {
+ # add(%x, %y)
+ # }
+
+ main = actual_mod["main"]
+ call = main.body
+ assert call.op.name == "call_lowered"
+ assert len(call.args) == 2
+ assert call.args[0].name_hint == "test_add"
+ assert len(call.args[1].fields) == 2
+ assert call.args[1].fields[0].name_hint == "a"
+ assert call.args[1].fields[1].name_hint == "a"
+ assert call.attrs.metadata["relay_attrs"].Primitive == 1
+ assert call.attrs.metadata["relay_attrs"].Compiler == "test_pass_lower_te"
+ assert call.attrs.metadata["relay_attrs"].global_symbol == "test_add"
+ assert len(call.attrs.metadata["all_prim_fn_vars"]) == 0
+
+ test_add = actual_mod["test_add"]
+ assert isinstance(test_add, tvm.relay.Function)
+ assert test_add.attrs["Extern"] == 1
+
+
+def test_lower_extern():
+ input_mod = tvm.parser.parse(
+ """
+ #[version = "0.0.5"]
+ def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] {
+ @my_add(%a, %a)
+ }
+ def @my_add(%x : Tensor[(5, 7), float32], %y : Tensor[(5, 7),
float32], Extern=1) -> Tensor[(5, 7), float32] {
+ add(%x, %y)
+ }
+ """,
+ "from_string",
+ None,
+ None,
+ )
+
+ actual_mod = transform(input_mod)
+
+ # Expected:
+ # def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] {
+ # %0 = (%a, %a);
+ # call_lowered(@my_add, %0, metadata={relay_attrs={Extern=1}},
all_prim_fn_vars=[]})
+ # }
+ # def @my_add(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32],
Extern=1) -> Tensor[(5, 7), float32] {
+ # add(%x, %y)
+ # }
+
+ main = actual_mod["main"]
+ call = main.body
+ assert call.op.name == "call_lowered"
+ assert len(call.args) == 2
+ assert call.args[0].name_hint == "my_add"
+ assert len(call.args[1].fields) == 2
+ assert call.args[1].fields[0].name_hint == "a"
+ assert call.args[1].fields[1].name_hint == "a"
+ assert call.attrs.metadata["relay_attrs"].Extern == 1
+ assert len(call.attrs.metadata["all_prim_fn_vars"]) == 0
+
+ test_add = actual_mod["my_add"]
+ assert isinstance(test_add, tvm.relay.Function)
+ assert test_add.attrs["Extern"] == 1
+
+
+def test_lower_extern_with_dynamic_shape():
+ input_mod = tvm.parser.parse(
+ """
+ #[version = "0.0.5"]
+ def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(?, ?), float32] {
+ @my_dyn(%a, %a)
+ }
+ def @my_dyn(%x : Tensor[(5, 7), float32], %y : Tensor[(5, 7),
float32], Extern=1) -> Tensor[(?, ?), float32] {
+ add(%x, %y)
+ }
+ """,
+ "from_string",
+ None,
+ None,
+ )
+
+ actual_mod = transform(input_mod)
+
+ # Expected:
+ # def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(?, ?), float32] {
+ # %0 = (%a, %a);
+ # call_lowered(@my_dyn, %0,
metadata={prim_shape_fn_var='shape_func_add', relay_attrs={Extern=1},
prim_shape_fn_states=[2, 2], prim_shape_fn_num_inputs=2,
all_prim_shape_fn_vars=['shape_func_add'], prim_shape_fn_num_outputs=1,
all_prim_fn_vars=[]})
+ # }
+ # def @my_dyn(%x: Tensor[(5, 7), float32] , %y: Tensor[(5, 7), float32] ,
Extern=1) -> Tensor[(?, ?), float32] {
+ # add(%x, %y)
+ # }
+ # def @shape_func_add = <shape PrimFunc>
+
+ main = actual_mod["main"]
+ call = main.body
+ assert call.op.name == "call_lowered"
+ assert len(call.args) == 2
+ assert call.args[0].name_hint == "my_dyn"
+ assert len(call.args[1].fields) == 2
+ assert call.args[1].fields[0].name_hint == "a"
+ assert call.args[1].fields[1].name_hint == "a"
+ assert call.attrs.metadata["prim_shape_fn_var"].name_hint ==
"shape_func_add"
+ assert call.attrs.metadata["relay_attrs"].Extern == 1
+ assert len(call.attrs.metadata["prim_shape_fn_states"]) == 2
+ assert call.attrs.metadata["prim_shape_fn_states"][0] == 2
+ assert call.attrs.metadata["prim_shape_fn_states"][1] == 2
+ assert call.attrs.metadata["prim_shape_fn_num_inputs"] == 2
+ assert len(call.attrs.metadata["all_prim_shape_fn_vars"]) == 1
+ assert call.attrs.metadata["all_prim_shape_fn_vars"][0].name_hint ==
"shape_func_add"
+ assert call.attrs.metadata["prim_shape_fn_num_outputs"] == 1
+ assert len(call.attrs.metadata["all_prim_fn_vars"]) == 0
+
+ my_dyn = actual_mod["my_dyn"]
+ assert isinstance(my_dyn, tvm.relay.Function)
+ assert my_dyn.attrs["Extern"] == 1
+
+ shape_func_add = actual_mod["shape_func_add"]
+ assert isinstance(shape_func_add, tvm.tir.PrimFunc)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()
diff --git a/tests/python/relay/transform/test_compiler_function_utils.py
b/tests/python/relay/transform/test_compiler_function_utils.py
index 13e0f98e79..b9eb115475 100644
--- a/tests/python/relay/transform/test_compiler_function_utils.py
+++ b/tests/python/relay/transform/test_compiler_function_utils.py
@@ -38,8 +38,7 @@ metatable = {
(2304,), # 1
(600, 32, 64), # 2
],
- ),
- "attributes": [{"relay_attrs": None}],
+ )
}
@@ -115,7 +114,7 @@ def expected_extern_mod():
"""
#[version = "0.0.5"]
def @main(%x0 : Tensor[(1600, 768), float16], %x3 : Tensor[(600, 32,
64), float16]) -> (Tensor[(1600, 2304), float16], Tensor[(600, 32, 32),
float16]) {
- %1 = call_lowered(@tvmgen_default_cutlass_main_0, (%x0,
meta[relay.Constant][0], meta[relay.Constant][1]),
metadata=meta[attributes][0]);
+ %1 = @tvmgen_default_cutlass_main_0(%x0, meta[relay.Constant][0],
meta[relay.Constant][1]);
%2 = fn(%y_3_i0: Tensor[(600, 32, 64), float16], %y_3_i1:
Tensor[(600, 32, 64), float16],
Inline=1, Compiler="cublas",
global_symbol="tvmgen_default_cublas_main_3", Primitive=1) -> Tensor[(600, 32,
32), float16] {
%6 = fn (%FunctionVar_0_01: Tensor[(600, 32, 64), float16],
%FunctionVar_0_11: Tensor[(600, 32, 64), float16],