electriclilies commented on a change in pull request #9483:
URL: https://github.com/apache/tvm/pull/9483#discussion_r759706150
##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -481,206 +557,217 @@ class LowerTensorExprMutator : public
DeviceAwareExprMutator {
*/
Expr MakeLoweredCall(Function func, Array<Expr> visited_args, Array<Type>
type_args, Span span,
Target target) {
- if (func->GetAttr<String>(attr::kCompiler).defined()) {
- // BYOC flow.
- CCacheKey key = CCacheKey(func, target);
- CachedFunc ext_func = compiler_->Lower(key, module_name_);
- ICHECK(ext_func.defined()) << "Lowering returned undefined function for "
- << ext_func->prim_fn_var->name_hint;
-
- // TODO(@areusch, @jroesch): this metadata is for AOT, this should be
our interface for AOT
- Map<GlobalVar, tir::PrimFunc> prim_fns;
- relay::Function func_with_metadata = func;
- func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var",
ext_func->prim_fn_var);
- func_with_metadata = WithAttr(func_with_metadata, "prim_funcs",
prim_fns);
- func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget,
ext_func->target);
-
- // Provide a callback hook which allows one-level up code generators to
- // act when we process a function.
- this->process_fn_(func_with_metadata);
-
- // TODO(mbs): Dynamic shapes?
- // TODO(@mbs, electriclilies): Make extern functions explicit
- return Call(ext_func->prim_fn_var, visited_args, Attrs(), type_args,
span);
-
- } else {
- // Non-External Relay Function
- CCacheKey key = CCacheKey(func, target);
- CachedFunc lowered_func = compiler_->Lower(key, module_name_);
-
- // Collect all the lowered functions produced for this primitive
function.
- Map<GlobalVar, tir::PrimFunc> prim_fns;
- Array<GlobalVar> all_prim_fn_vars;
- for (auto prim_fn : lowered_func->funcs->functions) {
- CHECK(prim_fn.second.as<tir::PrimFuncNode>()) << "must be a prim fn";
- prim_fns.Set(prim_fn.first, Downcast<tir::PrimFunc>(prim_fn.second));
- all_prim_fn_vars.push_back(prim_fn.first);
- }
-
- // TODO(@areusch, @jroesch): this metadata is for AOT, this should be
our interface for AOT
- relay::Function func_with_metadata = func;
- func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var",
lowered_func->prim_fn_var);
- func_with_metadata = WithAttr(func_with_metadata, "prim_funcs",
prim_fns);
- func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget,
lowered_func->target);
-
- // Provide a callback hook which allows one-level up code generators to
- // act when we process a function.
- this->process_fn_(func_with_metadata);
-
- auto call_lowered_attrs = make_object<CallLoweredAttrs>();
- if (func->HasNonzeroAttr(attr::kReshapeOnly)) {
- call_lowered_attrs->metadata.Set(attr::kReshapeOnly, tvm::Integer(1));
- }
-
- DeviceCopyProps props = GetDeviceCopyProps(func);
- if (props.body.defined()) {
- // Record the device copy source and destination scopes so the device
planner can
- // still follow along even after lowering.
- call_lowered_attrs->metadata.Set("src_se_scope", props.src_se_scope);
- call_lowered_attrs->metadata.Set("dst_se_scope", props.dst_se_scope);
+ CCacheKey key = CCacheKey(func, target);
+ CachedFunc cfunc = compiler_->Lower(key, module_name_);
+ ICHECK(cfunc.defined());
+
+ auto opt_compiler = func->GetAttr<String>(attr::kCompiler);
+
+ // Add some metadata on top of the *original function* and invoke the
callback so it can
+ // be captured.
+ // TODO(@areusch, @jroesch): this metadata is for AOT, this should be our
interface for AOT
+ Map<GlobalVar, tir::PrimFunc> prim_fns;
+ Array<GlobalVar> all_prim_fn_vars;
+ for (const auto& kv : cfunc->funcs->functions) {
+ if (opt_compiler) {
+ // We expect just the original func but with just the ExternalSymbol
attribute signalling
Review comment:
typo: signalling -> signaling
##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -951,22 +1039,87 @@ void UpdateFunctionMetadata(BaseFunc func,
function_metadata.Set(prim_fn_var.value()->name_hint, fi);
}
-IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn
process_fn) {
- TECompiler compiler;
-
- auto updated_module = LowerTensorExpr(module_name, compiler,
process_fn)(module);
-
- backend::UpdateAutoSchedulerOpWeights(compiler);
+IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn
process_fn,
+ SEScope host_se_scope) {
+ TECompiler compiler(module);
+
+ // TODO(mbs): This is all unnecessarily convoluted. Better would be to
accumulate the rewritten
+ // module as we go (including rewritten Functions, lowered primitives, and
runtime modules
+ // generated by external toolchains), and use a pair of maps over vars and
global vars
+ // to global vars to remember which functions have already been lowered.
+
+ // Lower all the callees in module:
+ // - Functions tagged with "Compiler" are unchanged (checked by
CreateFunctionPass)
+ // - Functions tagged with "Primitive" are unchanged (checked by
LowerTensorExprMutator)
+ // - Called functions tagged with "Compiler" are copied into the compiler
cache with a fresh
+ // GlobalVar, and calls updated (sticking with regular Relay Call).
+ // - Calls to functions tagged with "Primitive" are compiled to PrimFuncs,
and calls updated
+ // (using call_lowered convention).
+ IRModule updated_module = LowerTensorExpr(module_name, compiler,
std::move(process_fn),
+ std::move(host_se_scope))(module);
+
+ // The Functions tagged with "Compiler" are now residing in the cache ready
to be
+ // compiled by LowerExternalFunctions. However we still need a record of
them in the
+ // IRModule so that the various executors can see which function names need
to be
+ // retrieved. They may, however, have been renamed.
+ compiler->AddExterns(updated_module);
+
+ // Add the lowered functions.
+ IRModule lowered_module = compiler->GetLoweredFunctions();
+ VLOG(1) << "capturing " << lowered_module->functions.size() << " new lowered
functions";
+ for (const auto& kv : lowered_module->functions) {
+ if (updated_module->ContainGlobalVar(kv.first->name_hint)) {
+ LOG(FATAL) << "duplicate bindings for '" << kv.first->name_hint
+ << "'. Existing is:" << std::endl
+ << PrettyPrint(updated_module->Lookup(kv.first->name_hint))
<< std::endl
+ << "while new is:" << std::endl
+ << PrettyPrint(kv.second);
+ }
+ updated_module->Add(kv.first, kv.second);
+ }
- // Copy the lowered functions into the return module
- updated_module->Update(compiler->GetLoweredFunctions());
+ // Invoke external codegen for all Functions in the cache tagged with
"Compiler", and
+ // annotate the module with the resulting runtime modules.
+ // TODO(mbs): runtime modules should be first class rather than attributes.
+ Array<runtime::Module> external_mods =
+ module->GetAttr<Array<runtime::Module>>("external_mods",
Array<runtime::Module>()).value();
+ Array<runtime::Module> new_external_mods =
compiler->LowerExternalFunctions();
+ VLOG(1) << "capturing " << external_mods.size() << " existing and " <<
new_external_mods.size()
+ << " new external modules";
+ for (const auto& mod : new_external_mods) {
+ external_mods.push_back(mod); // copy-on-write.
+ }
- // Annotate the module with C Device API context mapping, the external
modules and function info
- // this is until we have Target's annotated for the C Device API
+ // Annotate the module with C Device API context mapping (this is until we
have Target's
+ // annotated for the C Device API)
// TODO(Mousius) - Remove "device_contexts" as soon as we have the graph
annotated properly with
// Target's
Review comment:
Again, Target's should be Targets (or targets).
##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -951,22 +1039,87 @@ void UpdateFunctionMetadata(BaseFunc func,
function_metadata.Set(prim_fn_var.value()->name_hint, fi);
}
-IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn
process_fn) {
- TECompiler compiler;
-
- auto updated_module = LowerTensorExpr(module_name, compiler,
process_fn)(module);
-
- backend::UpdateAutoSchedulerOpWeights(compiler);
+IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn
process_fn,
+ SEScope host_se_scope) {
+ TECompiler compiler(module);
+
+ // TODO(mbs): This is all unnecessarily convoluted. Better would be to
accumulate the rewritten
+ // module as we go (including rewritten Functions, lowered primitives, and
runtime modules
+ // generated by external toolchains), and use a pair of maps over vars and
global vars
+ // to global vars to remember which functions have already been lowered.
+
+ // Lower all the callees in module:
+ // - Functions tagged with "Compiler" are unchanged (checked by
CreateFunctionPass)
+ // - Functions tagged with "Primitive" are unchanged (checked by
LowerTensorExprMutator)
+ // - Called functions tagged with "Compiler" are copied into the compiler
cache with a fresh
+ // GlobalVar, and calls updated (sticking with regular Relay Call).
+ // - Calls to functions tagged with "Primitive" are compiled to PrimFuncs,
and calls updated
+ // (using call_lowered convention).
+ IRModule updated_module = LowerTensorExpr(module_name, compiler,
std::move(process_fn),
+ std::move(host_se_scope))(module);
+
+ // The Functions tagged with "Compiler" are now residing in the cache ready
to be
+ // compiled by LowerExternalFunctions. However we still need a record of
them in the
+ // IRModule so that the various executors can see which function names need
to be
+ // retrieved. They may, however, have been renamed.
+ compiler->AddExterns(updated_module);
+
+ // Add the lowered functions.
+ IRModule lowered_module = compiler->GetLoweredFunctions();
+ VLOG(1) << "capturing " << lowered_module->functions.size() << " new lowered
functions";
+ for (const auto& kv : lowered_module->functions) {
+ if (updated_module->ContainGlobalVar(kv.first->name_hint)) {
+ LOG(FATAL) << "duplicate bindings for '" << kv.first->name_hint
+ << "'. Existing is:" << std::endl
+ << PrettyPrint(updated_module->Lookup(kv.first->name_hint))
<< std::endl
+ << "while new is:" << std::endl
+ << PrettyPrint(kv.second);
+ }
+ updated_module->Add(kv.first, kv.second);
+ }
- // Copy the lowered functions into the return module
- updated_module->Update(compiler->GetLoweredFunctions());
+ // Invoke external codegen for all Functions in the cache tagged with
"Compiler", and
+ // annotate the module with the resulting runtime modules.
+ // TODO(mbs): runtime modules should be first class rather than attributes.
+ Array<runtime::Module> external_mods =
+ module->GetAttr<Array<runtime::Module>>("external_mods",
Array<runtime::Module>()).value();
+ Array<runtime::Module> new_external_mods =
compiler->LowerExternalFunctions();
+ VLOG(1) << "capturing " << external_mods.size() << " existing and " <<
new_external_mods.size()
+ << " new external modules";
+ for (const auto& mod : new_external_mods) {
+ external_mods.push_back(mod); // copy-on-write.
+ }
- // Annotate the module with C Device API context mapping, the external
modules and function info
- // this is until we have Target's annotated for the C Device API
+ // Annotate the module with C Device API context mapping (this is until we
have Target's
Review comment:
typo: Target's -> Targets
##########
File path: src/relay/backend/te_compiler_cache.cc
##########
@@ -427,26 +414,49 @@ class MakeShapeFunc : public
backend::MemoizedExprTranslator<Array<te::Tensor>>
candidate_name = truncated_name.str();
}
- // Set all the inputs correctly.
+ // Set all the inputs correctly, and accumulate their types from the
p.o.v. of the
+ // shape function rather than the primitive it is derived for.
Array<te::Tensor> inputs;
+ Array<Type> shape_function_arg_types;
for (auto param : prim_func->params) {
int state = param_states_[param];
shape_func_param_states.push_back(IntImm(DataType::Int(32), state));
if (state & kNeedInputData) {
+ // Pass the primitive arguments directly (though in flattened form and
on the host)
for (auto t : param_data_[param]) {
inputs.push_back(t);
+ shape_function_arg_types.push_back(TensorType(t->GetShape(),
t->GetDataType()));
}
}
if (state & kNeedInputShape) {
+ // Pass the shapes of the primitive arguments (also on the host)
for (auto t : param_shapes_[param]) {
inputs.push_back(t);
+ shape_function_arg_types.push_back(TensorType(t->GetShape(),
t->GetDataType()));
}
}
}
+ // TODO(mbs): This should be the definitive global by which the PrimFunc
is known and
+ // no other GlobalVar ctors should appear inside the lowering machinery.
auto func_name = renamer(candidate_name);
auto prim_fn_gvar = GlobalVar(func_name);
- prim_fn_gvar->checked_type_ = prim_func->checked_type();
+
+ // Gather the result types, again from the p.o.v. of the shape function
rather than
+ // the primitive it is derived for.
+ Array<Type> shape_function_res_types;
+ for (const auto& t : outputs) {
+ shape_function_res_types.push_back(TensorType(t->GetShape(),
t->GetDataType()));
+ }
+
+ // Assign the shape function it's true type.
Review comment:
typo: it's -> its
##########
File path: src/relay/backend/te_compiler_cache.cc
##########
@@ -471,10 +481,17 @@ class MakeShapeFunc : public
backend::MemoizedExprTranslator<Array<te::Tensor>>
With<PassContext> fresh_pass_ctx_scope(PassContext::Create());
std::unordered_map<te::Tensor, tir::Buffer> binds;
- IRModule ir_module = tvm::LowerSchedule(schedule, all_args, func_name,
binds);
-
+ IRModule lowered_module = tvm::LowerSchedule(schedule, all_args,
func_name, binds);
+
+ IRModule fixed_lowered_module;
+ for (const auto& kv : lowered_module->functions) {
Review comment:
Why do we need to fix the lowered module?
##########
File path: src/relay/backend/vm/compiler.cc
##########
@@ -1051,6 +995,26 @@ transform::Sequential MemoryOpt(const SEScope&
cpu_se_scope) {
return transform::Sequential(std::move(pass_seqs));
}
+transform::Sequential VMCompiler::LowerOperators(const SEScope& host_se_scope)
{
+ Array<Pass> pass_seqs;
Review comment:
It might be good to rename this to "fuse and lower" (or something else that
mentions fusion): in the places where you use it, you also have a note that it
fuses as well as lowers.
##########
File path: src/relay/backend/vm/compiler.cc
##########
@@ -481,50 +484,12 @@ class VMFunctionCompiler :
DeviceAwareExprFunctor<void(const Expr& n)> {
this->last_register_ = merge_register;
}
- void EmitShapeFunc(Function func, Array<Expr> inputs, Array<Expr> outputs) {
Review comment:
Cool that you got rid of this!!
##########
File path: src/relay/backend/vm/compiler.cc
##########
@@ -1102,7 +1066,24 @@ IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) {
pass_seqs.push_back(transform::ToANormalForm());
pass_seqs.push_back(transform::InferType());
pass_seqs.push_back(transform::LambdaLift());
- pass_seqs.push_back(transform::InlinePrimitives());
+
+ // Eliminate dead-code before we lower. We don't track the purity of
PrimFuncs, thus after
+ // lowering all calls to lowered functions will be kept.
+ pass_seqs.push_back(DeadCodeElimination(/*inline_once=*/false));
+ pass_seqs.push_back(transform::LabelOps());
+
+ // Lower all function's annotated as "primitive" by FuseOps.
Review comment:
typo: function's -> functions
##########
File path: src/relay/transforms/type_infer.cc
##########
@@ -395,7 +400,7 @@ class TypeInferencer : private ExprFunctor<Type(const
Expr&)>,
//
// The result will be the return type of the operator.
Type PrimitiveCall(const FuncTypeNode* op, Array<Type> arg_types, const
Attrs& attrs,
- const Span& span) {
+ const Span& span, const Expr& expr) {
Review comment:
It looks like you don't use `expr` in this function at all? Also, can you
rename `op` to `func_type_node`? :-)
##########
File path: src/relay/transforms/device_planner.cc
##########
@@ -765,7 +765,8 @@ class DeviceCapturer : public ExprMutator {
IRModule Capture() {
VLOG_CONTEXT << "CaptureDevices";
- IRModule result(/*functions=*/{}, mod_->type_definitions, mod_->Imports(),
mod_->source_map);
+ IRModule result(/*functions=*/{}, mod_->type_definitions, mod_->Imports(),
mod_->source_map,
Review comment:
Hmm, maybe we should add a WithFields for IRModule.
##########
File path: src/relay/backend/vm/compiler.cc
##########
@@ -481,50 +484,12 @@ class VMFunctionCompiler :
DeviceAwareExprFunctor<void(const Expr& n)> {
this->last_register_ = merge_register;
}
- void EmitShapeFunc(Function func, Array<Expr> inputs, Array<Expr> outputs) {
- // Lower shape function
- tec::CCacheKey key(func, host_se_scope_->target);
- auto cfunc = context_->compiler->LowerShapeFunc(key);
- int op_index = -1;
- // pick the only function inside the context
- ICHECK_EQ(cfunc->funcs->functions.size(), 1);
- auto pfunc =
Downcast<tir::PrimFunc>((*cfunc->funcs->functions.begin()).second);
- if (context_->seen_funcs.count(pfunc) == 0) {
- op_index = context_->cached_funcs.size();
- context_->cached_funcs.push_back(cfunc);
- context_->seen_funcs[pfunc] = op_index;
- } else {
- op_index = context_->seen_funcs[pfunc];
- }
-
- // Prepare input and output registers
+ void EmitInvokeTVMOp(const Expr& func, const Expr& inputs, const Expr&
outputs,
Review comment:
Why don't we need to pass SEScope in here anymore?
##########
File path: src/relay/op/memory/device_copy.cc
##########
@@ -117,15 +117,5 @@ DeviceCopyProps GetDeviceCopyProps(const Expr& expr) {
return {};
}
-DeviceCopyProps GetLoweredDeviceCopyProps(const CallLoweredProps& props) {
Review comment:
glad to see this simplification!
##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -481,206 +557,217 @@ class LowerTensorExprMutator : public
DeviceAwareExprMutator {
*/
Expr MakeLoweredCall(Function func, Array<Expr> visited_args, Array<Type>
type_args, Span span,
Target target) {
- if (func->GetAttr<String>(attr::kCompiler).defined()) {
- // BYOC flow.
- CCacheKey key = CCacheKey(func, target);
- CachedFunc ext_func = compiler_->Lower(key, module_name_);
- ICHECK(ext_func.defined()) << "Lowering returned undefined function for "
- << ext_func->prim_fn_var->name_hint;
-
- // TODO(@areusch, @jroesch): this metadata is for AOT, this should be
our interface for AOT
- Map<GlobalVar, tir::PrimFunc> prim_fns;
- relay::Function func_with_metadata = func;
- func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var",
ext_func->prim_fn_var);
- func_with_metadata = WithAttr(func_with_metadata, "prim_funcs",
prim_fns);
- func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget,
ext_func->target);
-
- // Provide a callback hook which allows one-level up code generators to
- // act when we process a function.
- this->process_fn_(func_with_metadata);
-
- // TODO(mbs): Dynamic shapes?
- // TODO(@mbs, electriclilies): Make extern functions explicit
- return Call(ext_func->prim_fn_var, visited_args, Attrs(), type_args,
span);
-
- } else {
- // Non-External Relay Function
- CCacheKey key = CCacheKey(func, target);
- CachedFunc lowered_func = compiler_->Lower(key, module_name_);
-
- // Collect all the lowered functions produced for this primitive
function.
- Map<GlobalVar, tir::PrimFunc> prim_fns;
- Array<GlobalVar> all_prim_fn_vars;
- for (auto prim_fn : lowered_func->funcs->functions) {
- CHECK(prim_fn.second.as<tir::PrimFuncNode>()) << "must be a prim fn";
- prim_fns.Set(prim_fn.first, Downcast<tir::PrimFunc>(prim_fn.second));
- all_prim_fn_vars.push_back(prim_fn.first);
- }
-
- // TODO(@areusch, @jroesch): this metadata is for AOT, this should be
our interface for AOT
- relay::Function func_with_metadata = func;
- func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var",
lowered_func->prim_fn_var);
- func_with_metadata = WithAttr(func_with_metadata, "prim_funcs",
prim_fns);
- func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget,
lowered_func->target);
-
- // Provide a callback hook which allows one-level up code generators to
- // act when we process a function.
- this->process_fn_(func_with_metadata);
-
- auto call_lowered_attrs = make_object<CallLoweredAttrs>();
- if (func->HasNonzeroAttr(attr::kReshapeOnly)) {
- call_lowered_attrs->metadata.Set(attr::kReshapeOnly, tvm::Integer(1));
- }
-
- DeviceCopyProps props = GetDeviceCopyProps(func);
- if (props.body.defined()) {
- // Record the device copy source and destination scopes so the device
planner can
- // still follow along even after lowering.
- call_lowered_attrs->metadata.Set("src_se_scope", props.src_se_scope);
- call_lowered_attrs->metadata.Set("dst_se_scope", props.dst_se_scope);
+ CCacheKey key = CCacheKey(func, target);
+ CachedFunc cfunc = compiler_->Lower(key, module_name_);
+ ICHECK(cfunc.defined());
+
+ auto opt_compiler = func->GetAttr<String>(attr::kCompiler);
+
+ // Add some metadata on top of the *original function* and invoke the
callback so it can
+ // be captured.
+ // TODO(@areusch, @jroesch): this metadata is for AOT, this should be our
interface for AOT
+ Map<GlobalVar, tir::PrimFunc> prim_fns;
+ Array<GlobalVar> all_prim_fn_vars;
+ for (const auto& kv : cfunc->funcs->functions) {
+ if (opt_compiler) {
+ // We expect just the original func but with just the ExternalSymbol
attribute signalling
+ // the function (will be) compiled externally.
+ ICHECK(kv.second.as<FunctionNode>())
+ << PrettyPrint(kv.first) << " must be bound to an (external)
Function";
+ } else {
+ // We expect one or more PrimFuncs, one of which corresponds to 'the'
lowered primitive
+ // (and the rest in support of that via tir::Calls).
+ ICHECK(kv.second.as<tir::PrimFuncNode>())
+ << PrettyPrint(kv.first) << " must be bound to a PrimFunc";
+ prim_fns.Set(kv.first, Downcast<tir::PrimFunc>(kv.second));
+ all_prim_fn_vars.push_back(kv.first);
}
+ }
+ Function func_with_metadata = func;
+ func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var",
cfunc->prim_fn_var);
+ func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns);
+ func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget,
cfunc->target);
+ this->process_fn_(func_with_metadata);
+
+ auto call_lowered_attrs = make_object<CallLoweredAttrs>();
+
+ // Non-External Relay Function
+ // TODO(mbs): "reshape" cleanup.
+ if (!opt_compiler && func->HasNonzeroAttr(attr::kReshapeOnly)) {
+ call_lowered_attrs->metadata.Set(attr::kReshapeOnly, tvm::Integer(1));
+ }
- call_lowered_attrs->metadata.Set("relay_attrs", func->attrs);
- call_lowered_attrs->metadata.Set("all_prim_fn_vars", all_prim_fn_vars);
-
- if (IsDynamic(func->ret_type)) {
- // Also lower the dynamic shape function.
- // Shape function keys use the underlying primitive function as their
'function',
- // but the generic 'cpu' target as the target since all shape
functions run
- // on the host cpu irrespective of where the primitive runs.
- // TODO(mbs): Cleanup target handling.
- Target shape_target("llvm");
- CCacheKey shape_key(func, shape_target);
- CachedFunc lowered_shape_func = compiler_->LowerShapeFunc(shape_key);
- // Capture the shape function's global var and parameters 'states' in
call
- // annotations so calling convention can be recovered.
- // TODO(mbs): Capture all this as part of a 'call into TIR' construct
once available.
- // The way the shape function calling convention is derived and passed
to call sites
- // via the 'parameter states' could be improved.
- call_lowered_attrs->metadata.Set("prim_shape_fn_var",
lowered_shape_func->prim_fn_var);
- call_lowered_attrs->metadata.Set("prim_shape_fn_states",
-
lowered_shape_func->shape_func_param_states);
- call_lowered_attrs->metadata.Set(
- "prim_shape_fn_num_inputs",
- Integer(static_cast<int>(lowered_shape_func->inputs.size())));
- call_lowered_attrs->metadata.Set(
- "prim_shape_fn_num_outputs",
- Integer(static_cast<int>(lowered_shape_func->outputs.size())));
- Array<GlobalVar> all_prim_shape_fn_vars;
- for (auto prim_shape_fn : lowered_shape_func->funcs->functions) {
- CHECK(prim_shape_fn.second.as<tir::PrimFuncNode>()) << "must be a
prim fn";
- all_prim_shape_fn_vars.push_back(prim_shape_fn.first);
- }
- call_lowered_attrs->metadata.Set("all_prim_shape_fn_vars",
all_prim_shape_fn_vars);
+ call_lowered_attrs->metadata.Set("relay_attrs", func->attrs);
+ call_lowered_attrs->metadata.Set("all_prim_fn_vars", all_prim_fn_vars);
+
+ if (IsDynamic(func->ret_type)) {
+ // Also lower the companion dynamic shape function.
+ // Shape function keys use the underlying primitive function as their
'function',
+ // but the generic 'cpu' target as the target since all shape functions
run
+ // on the host cpu irrespective of where the primitive runs.
+ CCacheKey shape_key(func, host_se_scope_->target);
+ CachedFunc lowered_shape_func = compiler_->LowerShapeFunc(shape_key);
+
+ // Capture the shape function's global var and parameters 'states' in
call
+ // annotations so calling convention can be recovered.
+ // TODO(mbs): Shape cleanup.
+ call_lowered_attrs->metadata.Set("prim_shape_fn_var",
lowered_shape_func->prim_fn_var);
+ call_lowered_attrs->metadata.Set("prim_shape_fn_states",
+
lowered_shape_func->shape_func_param_states);
+ call_lowered_attrs->metadata.Set(
+ "prim_shape_fn_num_inputs",
Integer(static_cast<int>(lowered_shape_func->inputs.size())));
+ call_lowered_attrs->metadata.Set(
+ "prim_shape_fn_num_outputs",
+ Integer(static_cast<int>(lowered_shape_func->outputs.size())));
+ Array<GlobalVar> all_prim_shape_fn_vars;
+ for (const auto& kv : lowered_shape_func->funcs->functions) {
+ CHECK(kv.second.as<tir::PrimFuncNode>()) << "must be a prim fn";
+ all_prim_shape_fn_vars.push_back(kv.first);
}
- return CallLowered(lowered_func->prim_fn_var, visited_args,
Attrs(call_lowered_attrs),
- type_args, span);
+ call_lowered_attrs->metadata.Set("all_prim_shape_fn_vars",
all_prim_shape_fn_vars);
}
+
+ return CallLowered(cfunc->prim_fn_var, std::move(visited_args),
Attrs(call_lowered_attrs),
+ type_args, std::move(span));
}
std::pair<Var, Expr> PreVisitLetBinding_(const Var& var, const Expr& value)
final {
Var new_var = Downcast<Var>(Mutate(var));
Expr new_value = Mutate(value);
BaseFunc prim_func = ResolveToPrimitive(new_value);
- if (prim_func.defined() && !prim_func->IsInstance<tir::PrimFuncNode>()) {
- // Remember let var is bound to (possibly indirectly) a non-tir
primitive.
- Function func = Downcast<Function>(prim_func);
- primitive_functions_.emplace(var, func);
+ if (prim_func.defined()) {
+ // Remember let var is bound (possibly indirectly) to a primitive
function.
+ primitive_functions_.emplace(var.get(), prim_func);
}
return {new_var, new_value};
}
Expr PostVisitLet_(const LetNode* pre_let_node, const LetNode*
post_let_node) final {
BaseFunc prim_func = ResolveToPrimitive(post_let_node->value);
- if (prim_func.defined() && !prim_func->IsInstance<tir::PrimFuncNode>()) {
+ if (prim_func.defined()) {
// Leaving let var scope
- primitive_functions_.erase(pre_let_node->var);
+ primitive_functions_.erase(pre_let_node->var.get());
}
return DeviceAwareExprMutator::PostVisitLet_(pre_let_node, post_let_node);
}
Expr DeviceAwareVisitExpr_(const FunctionNode* function_node) override {
- if (function_node->HasNonzeroAttr(attr::kPrimitive)) {
- // Nothing to lower inside primitive functions.
+ if (function_node->HasNonzeroAttr(attr::kPrimitive) ||
+ function_node->GetAttr<String>(attr::kExternalSymbol)) {
+ // Nothing to lower inside primitive/external functions.
return GetRef<Function>(function_node);
} else {
return DeviceAwareExprMutator::DeviceAwareVisitExpr_(function_node);
}
}
Expr DeviceAwareVisitExpr_(const CallNode* call_node) override {
- // Passes before lowering might insert a call_lowered to call a function
that has already
- // been lowered. Therefore we might see call_lowered ops here, but we
don't need to do anything
- // because ResolveToPrimitive returns null for all calls where the
call_node->op is an OpNode
- Call call = GetRef<Call>(call_node);
-
- // Look for (indirect) calls to primitives.
- BaseFunc prim_func = ResolveToPrimitive(call_node->op);
- if (!prim_func.defined()) {
- // Not a call_node to a primitive function.
- if (const FunctionNode* fn = call_node->op.as<FunctionNode>()) {
- this->process_fn_(GetRef<Function>(fn));
+ // We can see five forms of calls:
+ // 1. A 'normal' Relay call to a Function with the "primitive" attribute.
We will need
+ // to lower that to a global PrimFunc and rewrite the call to:
+ // call_lowered(@new_global, (arg1, ..., argn), <attributes>)
+ // However there are a few special forms which are excluded from this
treatment, see
+ // below.
+ // 2. A 'normal' Relay call to a Function with the "compiler" attribute.
We will need
+ // to invoke the appropriate BYOC toolchain function to yield a
runtime module and
+ // rewrite the call to the same form as above.
+ // 3. A 'normal' Relay call to a PrimFunc which has already been supplied
via a global
+ // definition. We rewrite to use the call_lowered form, but otherwise
nothing else
+ // needs to be done.
+ // 4. A 'normal' Relay call to a Relay Function without any special
attribute. These
+ // calls are not changed.
+ // 5. A call_lowered call from an earlier invocation of this pass.
+ // Note that ResolveToPrimitive will yield non-null only for cases 1-3.
+
+ // Look for (possibly indirect) calls to primitives.
+ BaseFunc primitive_func = ResolveToPrimitive(call_node->op);
+ if (!primitive_func.defined()) {
+ // Not a call to a primitive function we need to rewrite.
+ if (const auto* function_node = call_node->op.as<FunctionNode>()) {
+ process_fn_(GetRef<Function>(function_node));
}
- return ExprMutator::VisitExpr_(call_node);
+ return DeviceAwareExprMutator::DeviceAwareVisitExpr_(call_node);
}
- // Similarly transform arguments.
- Array<Expr> visited_args;
+ // Prepare the arguments.
+ Array<Expr> new_args;
for (const auto& arg : call_node->args) {
- visited_args.push_back(VisitExpr(arg));
+ new_args.push_back(VisitExpr(arg));
}
- // Already lowered by other means so we don't need to mutate
+ // Special case: device_copies are left as calls to primitive operators
+ // (thus undoing FuseOps) so that each backend can handle them directly.
+ // TODO(mbs): device_copy cleanup. Would be better for FuseOps to just
leave device_copy alone.
+ if (const auto* function_node = primitive_func.as<FunctionNode>()) {
+ DeviceCopyProps device_copy_props =
GetDeviceCopyProps(function_node->body);
+ if (device_copy_props.body.defined()) {
+ ICHECK_EQ(new_args.size(), 1);
+ return DeviceCopy(new_args[0], device_copy_props.src_se_scope,
+ device_copy_props.dst_se_scope);
+ }
+ }
+
+ // Special case: If already lowered by other means then so we don't need
to mutate
// the call but we do need to mutate the arguments
- if (prim_func->IsInstance<tir::PrimFuncNode>()) {
+ if (const auto* prim_func_node = primitive_func.as<tir::PrimFuncNode>()) {
// Function should already be Target annotated by this point
// but the TE Compiler metadata is still needed for the callback
// TODO(Mousius) - Robustify this to not assume we're in the GlobalVar
for Target Hooks
GlobalVar prim_func_var = Downcast<GlobalVar>(call_node->op);
- tir::PrimFunc downcast_prim_func = Downcast<tir::PrimFunc>(prim_func);
+ tir::PrimFunc prim_func = GetRef<tir::PrimFunc>(prim_func_node);
- Map<GlobalVar, tir::PrimFunc> prim_fns = {{prim_func_var,
downcast_prim_func}};
- tir::PrimFunc func_with_metadata =
- WithAttrs(downcast_prim_func, {
- {"prim_fn_var", prim_func_var},
- {"prim_funcs", prim_fns},
- });
+ Map<GlobalVar, tir::PrimFunc> prim_fns = {{prim_func_var, prim_func}};
+ tir::PrimFunc func_with_metadata = WithAttrs(prim_func, {
+
{"prim_fn_var", prim_func_var},
+
{"prim_funcs", prim_fns},
+ });
+
+ ICHECK(!IsDynamic(call_node->checked_type()));
+ auto call_lowered_attrs = make_object<CallLoweredAttrs>();
+ call_lowered_attrs->metadata.Set("relay_attrs", primitive_func->attrs);
- this->process_fn_(func_with_metadata);
- return Call(call_node->op, visited_args, call_node->attrs);
+ process_fn_(func_with_metadata);
+ return CallLowered(call_node->op, std::move(new_args),
Attrs(std::move(call_lowered_attrs)),
+ call_node->type_args, call_node->span);
}
+ // Typical case: call to fused primitive Relay Function.
// Find the desired target device.
Target target;
- if (prim_func->GetAttr<String>(attr::kCompiler).defined()) {
+ if (primitive_func->GetAttr<String>(attr::kCompiler).defined()) {
// The generic 'external device' target.
+ // TODO(mbs): Retire once replaced unified BYOC compiler and target
macihnery.
Review comment:
typo: macihnery -> machinery
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]