jroesch commented on a change in pull request #8597:
URL: https://github.com/apache/tvm/pull/8597#discussion_r682888122
##########
File path: src/relay/backend/interpreter.cc
##########
@@ -283,73 +290,153 @@ class Interpreter : public ExprFunctor<ObjectRef(const
Expr& n)>,
return MakeClosure(func);
}
- Array<Shape> ComputeDynamicShape(const Function& func, const
Array<ObjectRef>& args) {
- CCacheKey key(func, Target("llvm"));
- auto cfunc = compiler_->LowerShapeFunc(key);
- size_t arity = cfunc->inputs.size() + cfunc->outputs.size();
+ /*!
+ * \brief Returns packed function generated for TIR function bound to \p
prim_fn_var.
+ *
+ * \param prim_fn_var Global var for lowered primitve function.
+ * \param all_prim_fn_vars Global vars for all lowered primitive functions
needed by the above
+ * (including itself).
+ */
+ PackedFunc Build(const GlobalVar& prim_fn_var, const Array<GlobalVar>&
all_prim_fn_vars) {
Review comment:
Perhaps we can use a better name than `Build`; historically I think `Build`
is a pretty non-descriptive name.
##########
File path: src/relay/backend/interpreter.cc
##########
@@ -214,8 +217,12 @@ InterpreterState::InterpreterState(Expr current_expr,
InterpreterState::Stack st
class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
PatternFunctor<bool(const Pattern& p, const ObjectRef& v)>
{
public:
- Interpreter(IRModule mod, Device device, Target target)
- : mod_(mod), device_(device), target_(target),
debug_op_(Op::Get("debug")) {}
+ Interpreter(IRModule mod, IRModule lowered_mod, Device device, Target target)
Review comment:
We should add a TODO about coming back to merge these.
##########
File path: src/ir/module.cc
##########
@@ -349,20 +349,23 @@ void IRModuleNode::Update(const IRModule& mod) {
IRModule IRModule::FromExpr(const RelayExpr& expr,
const tvm::Map<GlobalVar, BaseFunc>& global_funcs,
- const tvm::Map<GlobalTypeVar, TypeData>&
type_definitions) {
- auto mod = IRModule(global_funcs, type_definitions);
+ const tvm::Map<GlobalTypeVar, TypeData>&
type_definitions,
+ std::unordered_set<String> import_set, const
std::string& name_hint) {
+ auto mod = IRModule(global_funcs, type_definitions, std::move(import_set));
BaseFunc func;
- std::string gv_name = "main";
+ std::string gv_name = name_hint;
if (auto* func_node = expr.as<BaseFuncNode>()) {
func = GetRef<BaseFunc>(func_node);
if (auto opt = func->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
gv_name = opt.value();
}
-
} else {
func = relay::Function(relay::FreeVars(expr), expr, Type(),
relay::FreeTypeVars(expr, mod), {});
}
+ if (gv_name.empty()) {
Review comment:
See the above comment; at some point we had a customizable entry point, but
much of the code assumes `main` these days.
##########
File path: src/relay/backend/interpreter.cc
##########
@@ -359,32 +446,25 @@ class Interpreter : public ExprFunctor<ObjectRef(const
Expr& n)>,
setter(arg_counter + i, arr);
};
- auto ret_type = func->body->checked_type();
size_t out_cnt = 0;
- if (auto rtype = ret_type.as<TupleTypeNode>()) {
+ if (auto rtype = ftn->ret_type.as<TupleTypeNode>()) {
+ // TODO(mbs): Recursive flatten?
Review comment:
See above comment.
##########
File path: src/parser/parser.cc
##########
@@ -417,7 +417,7 @@ class Parser {
* Useful for matching optional tokens, effectively looksahead by one.
*/
bool WhenMatch(const TokenType& token_type) {
- DLOG(INFO) << "Parser::WhenMatch: Peek() == " << Peek();
+ // DLOG(INFO) << "Parser::WhenMatch: Peek() == " << Peek();
Review comment:
Add the logging back? This is one reason why I want us to move to a
hierarchical version of logging — it is much easier to filter post facto.
##########
File path: src/relay/backend/interpreter.cc
##########
@@ -283,73 +290,153 @@ class Interpreter : public ExprFunctor<ObjectRef(const
Expr& n)>,
return MakeClosure(func);
}
- Array<Shape> ComputeDynamicShape(const Function& func, const
Array<ObjectRef>& args) {
- CCacheKey key(func, Target("llvm"));
- auto cfunc = compiler_->LowerShapeFunc(key);
- size_t arity = cfunc->inputs.size() + cfunc->outputs.size();
+ /*!
+ * \brief Returns packed function generated for TIR function bound to \p
prim_fn_var.
+ *
+ * \param prim_fn_var Global var for lowered primitve function.
+ * \param all_prim_fn_vars Global vars for all lowered primitive functions
needed by the above
+ * (including itself).
+ */
+ PackedFunc Build(const GlobalVar& prim_fn_var, const Array<GlobalVar>&
all_prim_fn_vars) {
+ auto itr = built_primitives_.find(prim_fn_var->name_hint);
+ if (itr != built_primitives_.end()) {
+ return itr->second;
+ }
+
+ // Project out just the primitive(s) we need.
+ // (Primitives may depend on other primitives).
+ IRModule lowered_projected_mod;
+ for (const auto& var : all_prim_fn_vars) {
+ ICHECK(lowered_mod_->ContainGlobalVar(var->name_hint));
+ lowered_projected_mod->Add(var, lowered_mod_->Lookup(var->name_hint));
+ }
+ // Build the projected module.
+ runtime::Module runtime_module;
+ if (const auto* f = runtime::Registry::Get("relay.backend.build")) {
+ // TODO(mbs): Deprecate hooks.
+ runtime_module = (*f)(lowered_projected_mod, target_);
+ } else {
+ runtime_module = build(lowered_projected_mod, target_,
/*target_host=*/Target(nullptr));
+ }
+
+ // Extract all the packed functions.
+ for (const auto& var : all_prim_fn_vars) {
+ PackedFunc packed_func = runtime_module.GetFunction(var->name_hint);
+ ICHECK_NOTNULL(packed_func);
+ built_primitives_.emplace(var->name_hint, packed_func);
+ }
+
+ // Return just what we need for this call.
+ itr = built_primitives_.find(prim_fn_var->name_hint);
+ ICHECK(itr != built_primitives_.end())
+ << "Can't find built function " << prim_fn_var->name_hint;
+ ICHECK_NOTNULL(itr->second);
+ return itr->second;
+ }
+
+ /*!
+ * \brief Call the primitive shape function bound to \p prim_shape_fn_var
passing the
+ * shapes of args, and return the resulting shapes.
+ *
+ * \param prim_shape_fn_var Global var bound to lowered shape function.
+ * \param all_prim_shape_fn_vars All the global vars needed to build the
above, including
+ * the shape function itself.
+ * \param prim_shape_fn_states For each arg, indicate whether the primitive
shape function
+ * requires the shape of the argument and/or the actual argument tensor.
+ * \param num_shape_inputs The number of inputs, after accounting for both
shapes vs data
+ * inputs and unfolding of tuple types.
+ * \param num_shape_outputs The number of outputs, after accounting for
unfolding of
+ * tuple types.
+ * \param args Arguments to the underlying primitive this shape function is
for.
+ * \return Expected shapes of the underlying primitive's outputs.
+ */
+ Array<Shape> ComputeDynamicShape(const GlobalVar& prim_shape_fn_var,
+ const Array<GlobalVar>&
all_prim_shape_fn_vars,
+ const Array<Integer>& prim_shape_fn_states,
+ size_t num_shape_inputs, size_t
num_shape_outputs,
+ const Array<ObjectRef>& args) {
+ ICHECK(prim_shape_fn_var.defined());
+ ICHECK(prim_shape_fn_states.defined());
+ ICHECK(prim_shape_fn_var->checked_type().defined());
+ // The function type is that of the original primitive rather than the
shape function
+ // itself. We currently can't express shape function types in Relay.
+ const FuncTypeNode* ftn =
prim_shape_fn_var->checked_type().as<FuncTypeNode>();
+ ICHECK(ftn);
+ ICHECK_EQ(prim_shape_fn_states.size(), ftn->arg_types.size());
+ ICHECK_EQ(args.size(), ftn->arg_types.size());
+ // num_shape_inputs will account for which primitive function arguments
are dynamic,
+ // whether the shape and or data needs to be passed, and flattening of
ADTs.
+ // Similarly, num_shape_outputs will account for flattening of ADTs.
+
+ PackedFunc packed_shape_func = Build(prim_shape_fn_var,
all_prim_shape_fn_vars);
+
+ size_t arity = num_shape_inputs + num_shape_outputs;
std::vector<TVMValue> values(arity);
std::vector<int> codes(arity);
TVMArgsSetter setter(values.data(), codes.data());
- std::vector<NDArray> inputs(cfunc->inputs.size());
- std::vector<NDArray> outputs(cfunc->outputs.size());
+ std::vector<NDArray> inputs(num_shape_inputs);
+ std::vector<NDArray> outputs(num_shape_outputs);
Device cpu_dev;
cpu_dev.device_type = kDLCPU;
cpu_dev.device_id = 0;
- auto fset_input = [&](size_t i, ObjectRef val, bool need_shape) {
+ auto fset_shape_input = [&](size_t i, ObjectRef val) {
+ DCHECK(i < num_shape_inputs);
auto nd_array = Downcast<NDArray>(val);
- if (need_shape) {
- int64_t ndim = nd_array.Shape().size();
- NDArray shape_arr;
- if (ndim == 0) {
- shape_arr = NDArray::Empty({}, DataType::Int(64), cpu_dev);
- } else {
- shape_arr = NDArray::Empty({ndim}, DataType::Int(64), cpu_dev);
- int64_t* data = reinterpret_cast<int64_t*>(shape_arr->data);
- for (auto j = 0; j < ndim; ++j) {
- data[j] = nd_array.Shape()[j];
- }
- }
- inputs[i] = shape_arr;
- setter(i, shape_arr);
+ int64_t ndim = nd_array.Shape().size();
+ NDArray shape_arr;
+ if (ndim == 0) {
+ shape_arr = NDArray::Empty({}, DataType::Int(64), cpu_dev);
} else {
- auto arr = nd_array.CopyTo(cpu_dev);
- inputs[i] = arr;
- setter(i, arr);
+ shape_arr = NDArray::Empty({ndim}, DataType::Int(64), cpu_dev);
+ int64_t* data = reinterpret_cast<int64_t*>(shape_arr->data);
+ for (auto j = 0; j < ndim; ++j) {
+ data[j] = nd_array.Shape()[j];
+ }
}
+ inputs[i] = shape_arr;
+ setter(i, shape_arr);
+ };
+
+ auto fset_data_input = [&](size_t i, ObjectRef val) {
+ DCHECK(i < num_shape_inputs);
+ auto nd_array = Downcast<NDArray>(val);
+ auto arr = nd_array.CopyTo(cpu_dev);
+ inputs[i] = arr;
+ setter(i, arr);
};
size_t arg_counter = 0;
for (size_t i = 0; i < args.size(); ++i) {
auto arg = args[i];
- auto param = func->params[i];
- int state = cfunc->shape_func_param_states[i]->value;
+ int64_t state = prim_shape_fn_states[i]->value;
Review comment:
You can probably clean up this code using these helpers:
```
/*! \brief Pack the sequence of vectors according to `ty`. */
Expr ToTupleType(const Type& ty, const std::vector<Expr>& exprs);
/*! \brief Unpack an expression according to the type into a sequence */
std::vector<Expr> FromTupleType(const Type& type, const Expr& expr);
/*! Return the "flattened" type which matches `FromTupleType`. */
std::vector<TensorType> FlattenTupleType(const Type& type);
```
I also added some missing comments.
See
https://github.com/apache/tvm/blob/main/src/relay/op/memory/memory.h#L39.
This is used by the `memory_alloc.cc` pass.
##########
File path: src/parser/tokenizer.h
##########
@@ -339,7 +339,7 @@ struct Tokenizer {
int line = this->line;
int col = this->col;
auto next = Peek();
- DLOG(INFO) << "tvm::parser::TokenizeOnce: next=" << next;
+ // DLOG(INFO) << "tvm::parser::TokenizeOnce: next=" << next;
Review comment:
Same as above.
##########
File path: src/relay/backend/interpreter.cc
##########
@@ -476,17 +552,19 @@ class Interpreter : public ExprFunctor<ObjectRef(const
Expr& n)>,
};
Array<Shape> out_shapes;
- auto ret_type = func->body->checked_type();
- bool is_dyn = IsDynamic(ret_type);
+ bool is_dyn = IsDynamic(ftn->ret_type);
if (is_dyn) {
- ICHECK(func->HasNonzeroAttr(attr::kPrimitive));
- out_shapes = ComputeDynamicShape(func, args);
+ ICHECK(prim_shape_fn_var.defined());
+ ICHECK(prim_shape_fn_states.defined());
+ out_shapes =
+ ComputeDynamicShape(prim_shape_fn_var, all_prim_shape_fn_vars,
prim_shape_fn_states,
+ num_shape_inputs, num_shape_outputs, args);
}
- PackedFunc packed_func = compiler_->JIT(CCacheKey(func, target_));
TVMRetValue rv;
- if (const TupleTypeNode* rtype =
func->body->checked_type().as<TupleTypeNode>()) {
+ if (const TupleTypeNode* rtype = ftn->ret_type.as<TupleTypeNode>()) {
+ // TODO(mbs): Recursive flatten?
Review comment:
See above comment.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]