areusch commented on code in PR #11619:
URL: https://github.com/apache/tvm/pull/11619#discussion_r892630449
##########
src/relay/backend/te_compiler.cc:
##########
@@ -697,43 +725,51 @@ class LowerTensorExprMutator : public
DeviceAwareExprMutator {
}
Expr DeviceAwareVisitExpr_(const CallNode* call_node) override {
- // We can see five forms of calls:
- // 1. A 'normal' Relay call to a Function with the "primitive" attribute.
We will need
- // to lower that to a global PrimFunc and rewrite the call to:
+ // We can see six forms of calls:
+ // 1. A 'normal' Relay call to a Function with the "Primitive" attribute
and not "Compiler"
+ // attribute. We will need to lower that to a global PrimFunc and
rewrite the call to:
// call_lowered(@new_global, (arg1, ..., argn), <attributes>)
- // However there are a few special forms which are excluded from this
treatment, see
- // below.
- // 2. A 'normal' Relay call to a Function with the "compiler" attribute.
We will need
- // to invoke the appropriate BYOC toolchain function to yield a
runtime module and
- // rewrite the call to the same form as above.
- // 3. A 'normal' Relay call to a PrimFunc which has already been supplied
via a global
- // definition. We rewrite to use the call_lowered form, but otherwise
nothing else
+ // If needed, the call needs to be cross-linked with any dynamic shape
functions.
Review Comment:
What do you mean by "cross-linked" here?
##########
src/relay/backend/te_compiler.cc:
##########
@@ -566,100 +566,128 @@ class LowerTensorExprMutator : public
DeviceAwareExprMutator {
return itr->second;
}
} else if (const auto* function_node = expr.as<FunctionNode>()) {
- if (!function_node->HasNonzeroAttr(attr::kPrimitive)) {
- // Not marked as primitive by FuseOps.
- return {};
- }
- if (const auto* call_node = function_node->body.as<CallNode>()) {
- if (call_node->op == debug_op_) {
- // Debug 'primitives' are not lowered.
- return {};
+ if (function_node->HasNonzeroAttr(attr::kExtern)) {
Review Comment:
Does this functionality need a test?
##########
src/relay/backend/te_compiler.cc:
##########
@@ -566,100 +566,128 @@ class LowerTensorExprMutator : public
DeviceAwareExprMutator {
return itr->second;
}
} else if (const auto* function_node = expr.as<FunctionNode>()) {
- if (!function_node->HasNonzeroAttr(attr::kPrimitive)) {
- // Not marked as primitive by FuseOps.
- return {};
- }
- if (const auto* call_node = function_node->body.as<CallNode>()) {
- if (call_node->op == debug_op_) {
- // Debug 'primitives' are not lowered.
- return {};
+ if (function_node->HasNonzeroAttr(attr::kExtern)) {
+ // We have a regular call to an 'extern' function. The call itself
needs to be rewritten
+ // to call_lowered form, and any required dynamic shape functions
generated and
+ // cross-linked.
+ return GetRef<Function>(function_node);
+ } else if (function_node->HasNonzeroAttr(attr::kPrimitive)) {
+ if (const auto* call_node = function_node->body.as<CallNode>()) {
+ if (call_node->op == debug_op_) {
+ // Debug 'primitives' are not lowered.
+ return {};
+ }
}
+ // We have a regular call to a 'primitive' function (possibly with a
'Compiler' attribute).
+ // We need to lower and rewrite the call.
+ return GetRef<Function>(function_node);
+ } else {
+ // Not marked as primitive during partitioning or TVM fusion.
+ return {};
}
- return GetRef<Function>(function_node);
} else {
return {};
}
}
/*!
- * \brief Lowers the primitive function \p func to TIR for ultimate execution
- * on a device with configuration \p target. Returns the global var bound
- * to the TIR implementation, and attributes to attach to the call to
identify it as
- * a TIR call.
+ * \brief Returns a 'call_lowered' call to \p prim_fn_var with \p args and
\p span with all the
+ * required attributes filled in. Generally \p prim_fn_var will correspond
to the lowered or
+ * externally codegen-ed form of \p original_function, where \p
lowered_functions binds all
+ * the required lowered functions.
+ *
+ * The call's attributes will capture:
+ * - Any attributes on the original_function.
+ * - All the lowered functions.
+ * TODO(mbs): Pretty sure that's no longer needed.
+ * - Details needed to cross-link the call to it's dynamic shape function,
if any.
*/
- Expr MakeLoweredCall(Function func, Array<Expr> visited_args, Span span,
Target target) {
- CCacheKey key = CCacheKey(func, target);
- CachedFunc cfunc = compiler_->Lower(key, module_name_);
- ICHECK(cfunc.defined());
-
- auto opt_compiler = func->GetAttr<String>(attr::kCompiler);
+ Expr MakeLoweredCall(const BaseFunc& original_function, const GlobalVar&
prim_fn_var,
+ Array<Expr> args, Span span, const Target& target,
+ const Map<GlobalVar, BaseFunc>& lowered_functions) {
+ auto opt_compiler = original_function->GetAttr<String>(attr::kCompiler);
// Add some metadata on top of the *original function* and invoke the
callback so it can
// be captured.
// TODO(@areusch, @jroesch): this metadata is for AOT, this should be our
interface for AOT
Map<GlobalVar, tir::PrimFunc> prim_fns;
Array<GlobalVar> all_prim_fn_vars;
- for (const auto& kv : cfunc->funcs->functions) {
+ for (const auto& kv : lowered_functions) {
if (opt_compiler) {
- // We expect just the original func but with just the ExternalSymbol
attribute signaling
- // the function (will be) compiled externally.
+ // We expect the original function to have just the "Extern" attribute
signaling the
+ // function (will be) compiled externally.
ICHECK(kv.second.as<FunctionNode>())
<< PrettyPrint(kv.first) << " must be bound to an (external)
Function";
} else {
- // We expect one or more PrimFuncs, one of which corresponds to 'the'
lowered primitive
- // (and the rest in support of that via tir::Calls).
+ // We expect one or more PrimFuncs, one of which corresponds to 'the'
lowered primitive,
+ // and the rest are in support of that via tir::Calls.
ICHECK(kv.second.as<tir::PrimFuncNode>())
<< PrettyPrint(kv.first) << " must be bound to a PrimFunc";
prim_fns.Set(kv.first, Downcast<tir::PrimFunc>(kv.second));
all_prim_fn_vars.push_back(kv.first);
}
}
- Function func_with_metadata = func;
- func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var",
cfunc->prim_fn_var);
- func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns);
- func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget,
cfunc->target);
- this->process_fn_(func_with_metadata);
+ // Alas, WithAttr cannot work with base classes.
Review Comment:
I wanted to see if a templated lambda would work here, but alas that's only
available in C++20.
##########
src/relay/transforms/compiler_function_utils.cc:
##########
@@ -167,20 +131,6 @@ transform::Pass MarkCompilerFunctionsAsExtern(std::string
compiler_filter) {
runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)>
pass_func =
[compiler_filter = std::move(compiler_filter)](IRModule mod,
transform::PassContext ctx) {
IRModule output_mod = mod->ShallowCopy();
-
- // First pass, rewrite the calls.
- // We have to do this before marking functions as 'extern' to know
which calls to rewrite!
Review Comment:
Checking my understanding: is this no longer needed because
LowerTensorExprMutator handles it?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]