mbs-octoml commented on a change in pull request #8597:
URL: https://github.com/apache/tvm/pull/8597#discussion_r683041378
##########
File path: src/relay/backend/interpreter.cc
##########
@@ -283,73 +290,153 @@ class Interpreter : public ExprFunctor<ObjectRef(const
Expr& n)>,
return MakeClosure(func);
}
- Array<Shape> ComputeDynamicShape(const Function& func, const
Array<ObjectRef>& args) {
- CCacheKey key(func, Target("llvm"));
- auto cfunc = compiler_->LowerShapeFunc(key);
- size_t arity = cfunc->inputs.size() + cfunc->outputs.size();
+ /*!
+ * \brief Returns packed function generated for TIR function bound to \p
prim_fn_var.
+ *
+ * \param prim_fn_var Global var for lowered primitive function.
+ * \param all_prim_fn_vars Global vars for all lowered primitive functions
needed by the above
+ * (including itself).
+ */
+ PackedFunc Build(const GlobalVar& prim_fn_var, const Array<GlobalVar>&
all_prim_fn_vars) {
+ auto itr = built_primitives_.find(prim_fn_var->name_hint);
+ if (itr != built_primitives_.end()) {
+ return itr->second;
+ }
+
+ // Project out just the primitive(s) we need.
+ // (Primitives may depend on other primitives).
+ IRModule lowered_projected_mod;
+ for (const auto& var : all_prim_fn_vars) {
+ ICHECK(lowered_mod_->ContainGlobalVar(var->name_hint));
+ lowered_projected_mod->Add(var, lowered_mod_->Lookup(var->name_hint));
+ }
+ // Build the projected module.
+ runtime::Module runtime_module;
+ if (const auto* f = runtime::Registry::Get("relay.backend.build")) {
+ // TODO(mbs): Deprecate hooks.
+ runtime_module = (*f)(lowered_projected_mod, target_);
+ } else {
+ runtime_module = build(lowered_projected_mod, target_,
/*target_host=*/Target(nullptr));
+ }
+
+ // Extract all the packed functions.
+ for (const auto& var : all_prim_fn_vars) {
+ PackedFunc packed_func = runtime_module.GetFunction(var->name_hint);
+ ICHECK_NOTNULL(packed_func);
+ built_primitives_.emplace(var->name_hint, packed_func);
+ }
+
+ // Return just what we need for this call.
+ itr = built_primitives_.find(prim_fn_var->name_hint);
+ ICHECK(itr != built_primitives_.end())
+ << "Can't find built function " << prim_fn_var->name_hint;
+ ICHECK_NOTNULL(itr->second);
+ return itr->second;
+ }
+
+ /*!
+ * \brief Call the primitive shape function bound to \p prim_shape_fn_var
passing the
+ * shapes of args, and return the resulting shapes.
+ *
+ * \param prim_shape_fn_var Global var bound to lowered shape function.
+ * \param all_prim_shape_fn_vars All the global vars needed to build the
above, including
+ * the shape function itself.
+ * \param prim_shape_fn_states For each arg, indicate whether the primitive
shape function
+ * requires the shape of the argument and/or the actual argument tensor.
+ * \param num_shape_inputs The number of inputs, after accounting for both
shapes vs data
+ * inputs and unfolding of tuple types.
+ * \param num_shape_outputs The number of outputs, after accounting for
unfolding of
+ * tuple types.
+ * \param args Arguments to the underlying primitive this shape function is
for.
+ * \return Expected shapes of the underlying primitive's outputs.
+ */
+ Array<Shape> ComputeDynamicShape(const GlobalVar& prim_shape_fn_var,
+ const Array<GlobalVar>&
all_prim_shape_fn_vars,
+ const Array<Integer>& prim_shape_fn_states,
+ size_t num_shape_inputs, size_t
num_shape_outputs,
+ const Array<ObjectRef>& args) {
+ ICHECK(prim_shape_fn_var.defined());
+ ICHECK(prim_shape_fn_states.defined());
+ ICHECK(prim_shape_fn_var->checked_type().defined());
+ // The function type is that of the original primitive rather than the
shape function
+ // itself. We currently can't express shape function types in Relay.
+ const FuncTypeNode* ftn =
prim_shape_fn_var->checked_type().as<FuncTypeNode>();
+ ICHECK(ftn);
+ ICHECK_EQ(prim_shape_fn_states.size(), ftn->arg_types.size());
+ ICHECK_EQ(args.size(), ftn->arg_types.size());
+ // num_shape_inputs will account for which primitive function arguments
are dynamic,
+ // whether the shape and or data needs to be passed, and flattening of
ADTs.
+ // Similarly, num_shape_outputs will account for flattening of ADTs.
+
+ PackedFunc packed_shape_func = Build(prim_shape_fn_var,
all_prim_shape_fn_vars);
+
+ size_t arity = num_shape_inputs + num_shape_outputs;
std::vector<TVMValue> values(arity);
std::vector<int> codes(arity);
TVMArgsSetter setter(values.data(), codes.data());
- std::vector<NDArray> inputs(cfunc->inputs.size());
- std::vector<NDArray> outputs(cfunc->outputs.size());
+ std::vector<NDArray> inputs(num_shape_inputs);
+ std::vector<NDArray> outputs(num_shape_outputs);
Device cpu_dev;
cpu_dev.device_type = kDLCPU;
cpu_dev.device_id = 0;
- auto fset_input = [&](size_t i, ObjectRef val, bool need_shape) {
+ auto fset_shape_input = [&](size_t i, ObjectRef val) {
+ DCHECK(i < num_shape_inputs);
auto nd_array = Downcast<NDArray>(val);
- if (need_shape) {
- int64_t ndim = nd_array.Shape().size();
- NDArray shape_arr;
- if (ndim == 0) {
- shape_arr = NDArray::Empty({}, DataType::Int(64), cpu_dev);
- } else {
- shape_arr = NDArray::Empty({ndim}, DataType::Int(64), cpu_dev);
- int64_t* data = reinterpret_cast<int64_t*>(shape_arr->data);
- for (auto j = 0; j < ndim; ++j) {
- data[j] = nd_array.Shape()[j];
- }
- }
- inputs[i] = shape_arr;
- setter(i, shape_arr);
+ int64_t ndim = nd_array.Shape().size();
+ NDArray shape_arr;
+ if (ndim == 0) {
+ shape_arr = NDArray::Empty({}, DataType::Int(64), cpu_dev);
} else {
- auto arr = nd_array.CopyTo(cpu_dev);
- inputs[i] = arr;
- setter(i, arr);
+ shape_arr = NDArray::Empty({ndim}, DataType::Int(64), cpu_dev);
+ int64_t* data = reinterpret_cast<int64_t*>(shape_arr->data);
+ for (auto j = 0; j < ndim; ++j) {
+ data[j] = nd_array.Shape()[j];
+ }
}
+ inputs[i] = shape_arr;
+ setter(i, shape_arr);
+ };
+
+ auto fset_data_input = [&](size_t i, ObjectRef val) {
+ DCHECK(i < num_shape_inputs);
+ auto nd_array = Downcast<NDArray>(val);
+ auto arr = nd_array.CopyTo(cpu_dev);
+ inputs[i] = arr;
+ setter(i, arr);
};
size_t arg_counter = 0;
for (size_t i = 0; i < args.size(); ++i) {
auto arg = args[i];
- auto param = func->params[i];
- int state = cfunc->shape_func_param_states[i]->value;
+ int64_t state = prim_shape_fn_states[i]->value;
Review comment:
This kNeedInputData/kNeedInputShape state per arg is a bit gnarly, since
it's gathered as a side effect of building the shape function. So I'm going
to leave a TODO to clean up the way we convey this to call sites rather than
kick that hornet's nest right now.
Done for TupleType flattening.
**However**, is ADT::Tuple coherent with FlattenTupleType? I.e., is this legitimate?
```
std::vector<TensorType> result_tensor_types =
FlattenTupleType(ftn->ret_type);
std::vector<ObjectRef> result_ndarrays;
for (int i = 0; i < result_tensor_types.size(); ++i) {
...
result_ndarrays.emplace_back(...)
}
if (result_tensor_types.size() == 1) {
return result_ndarrays[0];
} else {
return ADT::Tuple(result_ndarrays);
}
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]