mbs-octoml commented on a change in pull request #8597:
URL: https://github.com/apache/tvm/pull/8597#discussion_r683006012



##########
File path: include/tvm/ir/module.h
##########
@@ -307,20 +307,31 @@ class IRModule : public ObjectRef {
   }
 
   /*!
-   * \brief Construct a module from a standalone expression.
+   * \brief Constructs a module from a standalone expression \p expr.
    *
-   * Allows one to optionally pass a global function map and
-   * map of type definitions as well.
+   * If \p expr is a function it will be bound directly. Otherwise a function over the free
+   * variables of \p expr (possibly none) with \p expr as body is created and bound.
+   *
+   * The function is bound to, in preference order:
+   *  - The "global_symbol" attribute of \p expr, if it is a function with that attribute.
+   *  - \p name_hint, if non-empty.
+   *  - "main"
+   *
+   * Additional global functions and type definitions may be included in the result module.
    *
    * \param expr The expression to set as the main function to the module.
-   * \param global_funcs The global function map.
-   * \param type_definitions Map of global type definitions
+   * \param global_funcs The global function map. Default empty.
+   * \param type_definitions Map of global type definitions. Default empty.
+   * \param import_set Set of external modules already imported. Default empty.
+   * \param name_hint Name hint for global var to bind to \p expr. Default empty.
    *
-   * \returns A module with expr set as the main function.
+   * \returns A module with \p expr set as the main function.
    */
   TVM_DLL static IRModule FromExpr(const RelayExpr& expr,
                                    const Map<GlobalVar, BaseFunc>& global_funcs = {},
-                                   const Map<GlobalTypeVar, TypeData>& type_definitions = {});
+                                   const Map<GlobalTypeVar, TypeData>& type_definitions = {},
+                                   std::unordered_set<String> import_set = {},
+                                   const std::string& name_hint = std::string());

Review comment:
       Done. Kept FromExpr for the simple case (just an expr, always implicitly bound to 'main');
the general case is now FromExprInContext. Added yet another GetUniqueName to IRModule (though
left it private and not FFI-able for now). We should bring the N copies of this under control.
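
   For reference, a minimal sketch of the preference order via the Python binding
   (tvm.IRModule.from_expr); the toy expressions here are illustrative, not from this PR:
   ```
   import tvm
   from tvm import relay

   x = relay.var("x", shape=(3,), dtype="float32")
   body = relay.add(x, x)

   # A bare expression is wrapped in a function over its free variables and,
   # with no name hint, bound to "main".
   mod = tvm.IRModule.from_expr(body)
   print(mod["main"])

   # A function carrying a "global_symbol" attribute binds to that symbol instead.
   fn = relay.Function([x], body).with_attr("global_symbol", "my_fn")
   print(tvm.IRModule.from_expr(fn)["my_fn"])
   ```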

##########
File path: python/tvm/relay/analysis/analysis.py
##########
@@ -433,8 +433,7 @@ def get_calibration_data(mod, data):
     mod = _ffi_api.get_calibrate_module(mod)
     mod = transform.Inline()(mod)
 
-    ref_ex = build_module.create_executor("graph", mod=mod, device=cpu(0))
-    ref_res = ref_ex.evaluate()(**data)
+    ref_res = build_module.create_executor("graph", mod=mod, device=cpu(0)).evaluate()(**data)

Review comment:
       Agree, and I'm very close to doing it now since I went and rejigged 
**all** the create_executor/evaluate calls anyway. Left a TODO(mbs). Once we 
have issue tracking conventions I'll rework the TODOs into TODO(some specific 
issue) so we can organize this.
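
   A hedged sketch of the chained form now used throughout (assumes an LLVM-enabled
   build; the toy module here is illustrative):
   ```
   import numpy as np
   import tvm
   from tvm import relay

   x = relay.var("x", shape=(3,), dtype="float32")
   mod = tvm.IRModule.from_expr(relay.Function([x], relay.add(x, x)))

   # create_executor returns an executor whose evaluate() yields a Python
   # callable over the module's "main"; keyword args map to parameter names.
   ref_res = relay.create_executor("graph", mod=mod, device=tvm.cpu(0), target="llvm").evaluate()(
       x=np.ones((3,), dtype="float32")
   )
   ```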

##########
File path: src/ir/module.cc
##########
@@ -349,20 +349,23 @@ void IRModuleNode::Update(const IRModule& mod) {
 
 IRModule IRModule::FromExpr(const RelayExpr& expr,
                             const tvm::Map<GlobalVar, BaseFunc>& global_funcs,
-                            const tvm::Map<GlobalTypeVar, TypeData>& type_definitions) {
-  auto mod = IRModule(global_funcs, type_definitions);
+                            const tvm::Map<GlobalTypeVar, TypeData>& type_definitions,
+                            std::unordered_set<String> import_set, const std::string& name_hint) {
+  auto mod = IRModule(global_funcs, type_definitions, std::move(import_set));
   BaseFunc func;
-  std::string gv_name = "main";
+  std::string gv_name = name_hint;
 
   if (auto* func_node = expr.as<BaseFuncNode>()) {
     func = GetRef<BaseFunc>(func_node);
     if (auto opt = func->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
       gv_name = opt.value();
     }
-
   } else {
     func = relay::Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod), {});
   }
+  if (gv_name.empty()) {

Review comment:
       Yeah, 'main' is endemic now. I got part of the way there addressing your FromExpr feedback
above. I started to remove the dependence on 'main', but that got out of hand quickly, so I put it
back in the bottle.

##########
File path: src/parser/parser.cc
##########
@@ -417,7 +417,7 @@ class Parser {
    * Useful for matching optional tokens, effectively looksahead by one.
    */
   bool WhenMatch(const TokenType& token_type) {
-    DLOG(INFO) << "Parser::WhenMatch: Peek() == " << Peek();
+    // DLOG(INFO) << "Parser::WhenMatch: Peek() == " << Peek();

Review comment:
       Ah, right! I started on per-file VLOGs, but hoisted that out of here. Commented these out
since I couldn't see the wood for the trees; will undo, sorry.

##########
File path: src/relay/backend/interpreter.cc
##########
@@ -214,8 +217,12 @@ InterpreterState::InterpreterState(Expr current_expr, InterpreterState::Stack st
 class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
                     PatternFunctor<bool(const Pattern& p, const ObjectRef& v)> {
  public:
-  Interpreter(IRModule mod, Device device, Target target)
-      : mod_(mod), device_(device), target_(target), debug_op_(Op::Get("debug")) {}
+  Interpreter(IRModule mod, IRModule lowered_mod, Device device, Target target)

Review comment:
       Done.

##########
File path: src/relay/backend/interpreter.cc
##########
@@ -283,73 +290,153 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
     return MakeClosure(func);
   }
 
-  Array<Shape> ComputeDynamicShape(const Function& func, const Array<ObjectRef>& args) {
-    CCacheKey key(func, Target("llvm"));
-    auto cfunc = compiler_->LowerShapeFunc(key);
-    size_t arity = cfunc->inputs.size() + cfunc->outputs.size();
+  /*!
+   * \brief Returns the packed function generated for the TIR function bound to \p prim_fn_var.
+   *
+   * \param prim_fn_var Global var for lowered primitive function.
+   * \param all_prim_fn_vars Global vars for all lowered primitive functions needed by the above
+   * (including itself).
+   */
+  PackedFunc Build(const GlobalVar& prim_fn_var, const Array<GlobalVar>& all_prim_fn_vars) {

Review comment:
       Done. PrimitiveToPackedFunc.

##########
File path: src/relay/backend/interpreter.cc
##########
@@ -283,73 +290,153 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
     return MakeClosure(func);
   }
 
-  Array<Shape> ComputeDynamicShape(const Function& func, const Array<ObjectRef>& args) {
-    CCacheKey key(func, Target("llvm"));
-    auto cfunc = compiler_->LowerShapeFunc(key);
-    size_t arity = cfunc->inputs.size() + cfunc->outputs.size();
+  /*!
+   * \brief Returns the packed function generated for the TIR function bound to \p prim_fn_var.
+   *
+   * \param prim_fn_var Global var for lowered primitive function.
+   * \param all_prim_fn_vars Global vars for all lowered primitive functions needed by the above
+   * (including itself).
+   */
+  PackedFunc Build(const GlobalVar& prim_fn_var, const Array<GlobalVar>& all_prim_fn_vars) {
+    auto itr = built_primitives_.find(prim_fn_var->name_hint);
+    if (itr != built_primitives_.end()) {
+      return itr->second;
+    }
+
+    // Project out just the primitive(s) we need.
+    // (Primitives may depend on other primitives).
+    IRModule lowered_projected_mod;
+    for (const auto& var : all_prim_fn_vars) {
+      ICHECK(lowered_mod_->ContainGlobalVar(var->name_hint));
+      lowered_projected_mod->Add(var, lowered_mod_->Lookup(var->name_hint));
+    }
 
+    // Build the projected module.
+    runtime::Module runtime_module;
+    if (const auto* f = runtime::Registry::Get("relay.backend.build")) {
+      // TODO(mbs): Deprecate hooks.
+      runtime_module = (*f)(lowered_projected_mod, target_);
+    } else {
+      runtime_module = build(lowered_projected_mod, target_, /*target_host=*/Target(nullptr));
+    }
+
+    // Extract all the packed functions.
+    for (const auto& var : all_prim_fn_vars) {
+      PackedFunc packed_func = runtime_module.GetFunction(var->name_hint);
+      ICHECK_NOTNULL(packed_func);
+      built_primitives_.emplace(var->name_hint, packed_func);
+    }
+
+    // Return just what we need for this call.
+    itr = built_primitives_.find(prim_fn_var->name_hint);
+    ICHECK(itr != built_primitives_.end())
+        << "Can't find built function " << prim_fn_var->name_hint;
+    ICHECK_NOTNULL(itr->second);
+    return itr->second;
+  }
+
+  /*!
+   * \brief Call the primitive shape function bound to \p prim_shape_fn_var passing the
+   * shapes of args, and return the resulting shapes.
+   *
+   * \param prim_shape_fn_var Global var bound to lowered shape function.
+   * \param all_prim_shape_fn_vars All the global vars needed to build the above, including
+   * the shape function itself.
+   * \param prim_shape_fn_states For each arg, indicate whether the primitive shape function
+   * requires the shape of the argument and/or the actual argument tensor.
+   * \param num_shape_inputs The number of inputs, after accounting for both shapes vs data
+   * inputs and unfolding of tuple types.
+   * \param num_shape_outputs The number of outputs, after accounting for unfolding of
+   * tuple types.
+   * \param args Arguments to the underlying primitive this shape function is for.
+   * \return Expected shapes of the underlying primitive's outputs.
+   */
+  Array<Shape> ComputeDynamicShape(const GlobalVar& prim_shape_fn_var,
+                                   const Array<GlobalVar>& all_prim_shape_fn_vars,
+                                   const Array<Integer>& prim_shape_fn_states,
+                                   size_t num_shape_inputs, size_t num_shape_outputs,
+                                   const Array<ObjectRef>& args) {
+    ICHECK(prim_shape_fn_var.defined());
+    ICHECK(prim_shape_fn_states.defined());
+    ICHECK(prim_shape_fn_var->checked_type().defined());
+    // The function type is that of the original primitive rather than the shape function
+    // itself. We currently can't express shape function types in Relay.
+    const FuncTypeNode* ftn = prim_shape_fn_var->checked_type().as<FuncTypeNode>();
+    ICHECK(ftn);
+    ICHECK_EQ(prim_shape_fn_states.size(), ftn->arg_types.size());
+    ICHECK_EQ(args.size(), ftn->arg_types.size());
+    // num_shape_inputs will account for which primitive function arguments are dynamic,
+    // whether the shape and/or data needs to be passed, and flattening of ADTs.
+    // Similarly, num_shape_outputs will account for flattening of ADTs.
+
+    PackedFunc packed_shape_func = Build(prim_shape_fn_var, all_prim_shape_fn_vars);
+
+    size_t arity = num_shape_inputs + num_shape_outputs;
     std::vector<TVMValue> values(arity);
     std::vector<int> codes(arity);
     TVMArgsSetter setter(values.data(), codes.data());
-    std::vector<NDArray> inputs(cfunc->inputs.size());
-    std::vector<NDArray> outputs(cfunc->outputs.size());
+    std::vector<NDArray> inputs(num_shape_inputs);
+    std::vector<NDArray> outputs(num_shape_outputs);
 
     Device cpu_dev;
     cpu_dev.device_type = kDLCPU;
     cpu_dev.device_id = 0;
 
-    auto fset_input = [&](size_t i, ObjectRef val, bool need_shape) {
+    auto fset_shape_input = [&](size_t i, ObjectRef val) {
+      DCHECK(i < num_shape_inputs);
       auto nd_array = Downcast<NDArray>(val);
-      if (need_shape) {
-        int64_t ndim = nd_array.Shape().size();
-        NDArray shape_arr;
-        if (ndim == 0) {
-          shape_arr = NDArray::Empty({}, DataType::Int(64), cpu_dev);
-        } else {
-          shape_arr = NDArray::Empty({ndim}, DataType::Int(64), cpu_dev);
-          int64_t* data = reinterpret_cast<int64_t*>(shape_arr->data);
-          for (auto j = 0; j < ndim; ++j) {
-            data[j] = nd_array.Shape()[j];
-          }
-        }
-        inputs[i] = shape_arr;
-        setter(i, shape_arr);
+      int64_t ndim = nd_array.Shape().size();
+      NDArray shape_arr;
+      if (ndim == 0) {
+        shape_arr = NDArray::Empty({}, DataType::Int(64), cpu_dev);
       } else {
-        auto arr = nd_array.CopyTo(cpu_dev);
-        inputs[i] = arr;
-        setter(i, arr);
+        shape_arr = NDArray::Empty({ndim}, DataType::Int(64), cpu_dev);
+        int64_t* data = reinterpret_cast<int64_t*>(shape_arr->data);
+        for (auto j = 0; j < ndim; ++j) {
+          data[j] = nd_array.Shape()[j];
+        }
       }
+      inputs[i] = shape_arr;
+      setter(i, shape_arr);
+    };
+
+    auto fset_data_input = [&](size_t i, ObjectRef val) {
+      DCHECK(i < num_shape_inputs);
+      auto nd_array = Downcast<NDArray>(val);
+      auto arr = nd_array.CopyTo(cpu_dev);
+      inputs[i] = arr;
+      setter(i, arr);
     };
 
     size_t arg_counter = 0;
     for (size_t i = 0; i < args.size(); ++i) {
       auto arg = args[i];
-      auto param = func->params[i];
-      int state = cfunc->shape_func_param_states[i]->value;
+      int64_t state = prim_shape_fn_states[i]->value;

Review comment:
       This kNeedInputData/kNeedInputShape state per arg is a bit gnarly since it's gathered as a
side-effect of building the shape function. So I'm gonna leave a TODO to clean up the way we
convey this to call sites rather than kick that hornets' nest right now.
   
   Done for TupleType flattening.
   
   **However**, is ADT::Tuple coherent with FlattenTupleType? I.e., is this legit?
   ```
   std::vector<TensorType> result_tensor_types = FlattenTupleType(ftn->ret_type);
   std::vector<ObjectRef> result_ndarrays;
   for (int i = 0; i < result_tensor_types.size(); ++i) {
      ...
      result_ndarrays.emplace_back(...)
   }
   if (result_tensor_types.size() == 1) {
     return result_ndarrays[0];
   } else {
     return ADT::Tuple(result_ndarrays);
   }
   ```
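
   Separately, as a sanity check on the overall flow (not part of this diff): a hedged
   sketch of the situation the interpreter's shape-function path handles, where the result
   shape is only known at run time. Assumes an LLVM-enabled build; relay.Any() marks the
   dynamic dimension.
   ```
   import numpy as np
   import tvm
   from tvm import relay

   x = relay.var("x", shape=(relay.Any(), 2), dtype="float32")
   mod = tvm.IRModule.from_expr(relay.Function([x], relay.add(x, x)))

   # The debug executor is the interpreter; a dynamic result type forces it to
   # run the primitive's shape function before allocating the output.
   res = relay.create_executor("debug", mod=mod, device=tvm.cpu(0), target="llvm").evaluate()(
       np.ones((4, 2), dtype="float32")
   )
   ```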
   

##########
File path: src/relay/backend/interpreter.cc
##########
@@ -359,32 +446,25 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
       setter(arg_counter + i, arr);
     };
 
-    auto ret_type = func->body->checked_type();
     size_t out_cnt = 0;
-    if (auto rtype = ret_type.as<TupleTypeNode>()) {
+    if (auto rtype = ftn->ret_type.as<TupleTypeNode>()) {
+      // TODO(mbs): Recursive flatten?

Review comment:
       Done. Nice Tetris effect of cascading simplifications; we should do this everywhere.

##########
File path: src/relay/backend/interpreter.cc
##########
@@ -476,17 +552,19 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
     };
 
     Array<Shape> out_shapes;
-    auto ret_type = func->body->checked_type();
-    bool is_dyn = IsDynamic(ret_type);
+    bool is_dyn = IsDynamic(ftn->ret_type);
 
     if (is_dyn) {
-      ICHECK(func->HasNonzeroAttr(attr::kPrimitive));
-      out_shapes = ComputeDynamicShape(func, args);
+      ICHECK(prim_shape_fn_var.defined());
+      ICHECK(prim_shape_fn_states.defined());
+      out_shapes =
+          ComputeDynamicShape(prim_shape_fn_var, all_prim_shape_fn_vars, prim_shape_fn_states,
+                              num_shape_inputs, num_shape_outputs, args);
     }
 
-    PackedFunc packed_func = compiler_->JIT(CCacheKey(func, target_));
     TVMRetValue rv;
-    if (const TupleTypeNode* rtype = func->body->checked_type().as<TupleTypeNode>()) {
+    if (const TupleTypeNode* rtype = ftn->ret_type.as<TupleTypeNode>()) {
+      // TODO(mbs): Recursive flatten?

Review comment:
       Done.



