This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d2a14a6880 [BYOC] Switch TensorRT BYOC integration to 
IRModule-at-a-time using RelayToTIR hook (#11979)
d2a14a6880 is described below

commit d2a14a6880ee2e4520f3e9a55accb258e8725e65
Author: Mark Shields <[email protected]>
AuthorDate: Fri Jul 1 15:09:06 2022 -0700

    [BYOC] Switch TensorRT BYOC integration to IRModule-at-a-time using 
RelayToTIR hook (#11979)
    
    * [BYOC] Switch TensorRT BYOC integration to IRModule-at-a-time using 
RelayToTIR hook
    
    This does for the TensorRT integration what #11631 did for the CUTLASS 
integration.
    
    - All compilation options are captured within the attributes of a Target of
      kind "tensorrt" (instead of the "relay.ext.tensorrt.options" attribute in
      PassContext). This means all BYOC configuration options needed by Collage
      can be captured uniformly by a list-of-Targets. It also means RPC
      boundaries (as used internally at OctoML) only need to worry about
      maintaining the fidelity of the Target instance(s) rather than reaching
      into the PassContext.
    
    - Compilation is switched from function-at-a-time (relying on the
      TECompiler) to IRModule-at-a-time (using the RelayToTIR target-specific
      hook mechanism). Though not strictly necessary for Collage, I want to
      check the path is now clear to deprecate the support for BYOC in
      TECompiler.
    
    - Get all the TensorRT tests going again, except for a few I've disabled
      with an x-link to a new issue #11765. CAUTION: The TensorRT runtime is
      not supported in CI so many of these tests are cosmetic.
    
    - While trying to track down a 'free(): invalid pointer' error in
      test_tensorrt_int8_exp.py, I made the TensorRT allocs/frees more robust,
      but it turns out it's also broken in main. No harm leaving these changes
      in though.
    
    * - Lints
    
    * - Woops, fix test
    
    * - lints
    
    * - Use default tensorrt target if none given in targets list
    
    * - fix free error
    
    * - accidentally introduced 'transforms' namespace
    - can't use default Target("tensorrt") arg
    
    * - D'oh! Include ended up #if protected
    
    * - restore mark for test_dynamic_offload
    - handle missing runtime in versioning
    - turn test_maskrcnn_resnet50 back on now that we have the
      import-torch-first workaround.
    
    * - wibble
---
 include/tvm/runtime/module.h                       |   2 +-
 .../meta_schedule/testing/custom_builder_runner.py |   7 +-
 python/tvm/relay/op/contrib/tensorrt.py            | 191 +++++++--------
 src/relay/backend/contrib/codegen_c/codegen.cc     |  10 +-
 src/relay/backend/contrib/cutlass/codegen.cc       |  10 +-
 src/relay/backend/contrib/tensorrt/codegen.cc      | 265 ++++++++++++---------
 .../contrib/tensorrt/{target.cc => codegen.h}      |  23 +-
 src/relay/backend/contrib/tensorrt/target.cc       |  31 ++-
 src/relay/transforms/compiler_function_utils.cc    |  16 +-
 src/relay/transforms/compiler_function_utils.h     |  15 +-
 src/runtime/const_loader_module.cc                 |  24 +-
 src/runtime/contrib/json/json_runtime.h            |   2 +
 src/runtime/contrib/tensorrt/tensorrt_builder.cc   |  27 ++-
 src/runtime/contrib/tensorrt/tensorrt_builder.h    |  10 +-
 src/runtime/contrib/tensorrt/tensorrt_ops.cc       |   4 +-
 src/runtime/contrib/tensorrt/tensorrt_runtime.cc   |  14 +-
 src/target/metadata_module.cc                      |   2 -
 tests/python/contrib/test_tensorrt.py              | 172 ++++++-------
 tests/python/contrib/test_tensorrt_int8_exp.py     |  23 +-
 19 files changed, 491 insertions(+), 357 deletions(-)

diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h
index 31d05571ee..9d139c9fef 100644
--- a/include/tvm/runtime/module.h
+++ b/include/tvm/runtime/module.h
@@ -113,7 +113,7 @@ class Module : public ObjectRef {
 class TVM_DLL ModuleNode : public Object {
  public:
   /*! \brief virtual destructor */
-  virtual ~ModuleNode() {}
+  virtual ~ModuleNode() = default;
   /*!
    * \return The per module type key.
    * \note This key is used to for serializing custom modules.
diff --git a/python/tvm/meta_schedule/testing/custom_builder_runner.py 
b/python/tvm/meta_schedule/testing/custom_builder_runner.py
index e203848c2c..1cfd4ab833 100644
--- a/python/tvm/meta_schedule/testing/custom_builder_runner.py
+++ b/python/tvm/meta_schedule/testing/custom_builder_runner.py
@@ -85,11 +85,8 @@ def build_relay_with_tensorrt(
     from tvm.relay.op.contrib import tensorrt
     from tvm.runtime import Module
 
-    mod, config = tensorrt.partition_for_tensorrt(mod, params)
-    with PassContext(
-        opt_level=3,
-        config={"relay.ext.tensorrt.options": config},
-    ):
+    mod = tensorrt.partition_for_tensorrt(mod, params)
+    with PassContext(opt_level=3):
         result = relay_build(mod, target=target, target_host=None, 
params=params)
     assert isinstance(result, Module)
     return result
diff --git a/python/tvm/relay/op/contrib/tensorrt.py 
b/python/tvm/relay/op/contrib/tensorrt.py
index a69e2d4105..4008b0eb3f 100644
--- a/python/tvm/relay/op/contrib/tensorrt.py
+++ b/python/tvm/relay/op/contrib/tensorrt.py
@@ -33,6 +33,10 @@ from tvm.relay.op.contrib.register import 
register_pattern_table
 logger = logging.getLogger("TensorRT")
 
 
+def is_tensorrt_compiler_enabled() -> bool:
+    return tvm.get_global_func("relay.ext.tensorrt.is_runtime_enabled", True) 
is not None
+
+
 def is_tensorrt_runtime_enabled() -> bool:
     """Check if the TensorRT graph executor is present.
     Returns
@@ -40,118 +44,105 @@ def is_tensorrt_runtime_enabled() -> bool:
     ret: bool
         True if present, False if not.
     """
-    check_enabled = 
tvm.get_global_func("relay.op.is_tensorrt_runtime_enabled", True)
+    check_enabled = 
tvm.get_global_func("relay.ext.tensorrt.is_runtime_enabled", True)
     if check_enabled:
         return check_enabled()
     return False
 
 
+def get_tensorrt_target() -> tvm.target.Target:
+    """Returns the current Target, which must be of kind "tensorrt"."""
+    target = tvm.target.Target.current()
+    if target is None or target.kind.name != "tensorrt":
+        # Create the default target.
+        return tvm.target.Target("tensorrt")
+    return target
+
+
 def get_tensorrt_version() -> Tuple[int, int, int]:
-    """Gets the version of TensorRT that TVM is built against or is targeting.
+    """Returns the version of TensorRT to assume during compilation.
+    In order of preference this is taken from:
+     - The current "tensorrt" target's "tensorrt_version" attribute string.
+     - The version linked to the TVM runtime.
+     - (6, 0, 1)
 
     Returns
     -------
     ret: Tuple[int, int, int]
-        TensorRT version as a tuple of major, minor, and patch number. If TVM
-        is not built with TensorRT, the value set by set_tensorrt_version() is 
returned instead.
+        TensorRT version as a tuple of (major, minor, patch).
     """
-    pass_ctx = tvm.transform.PassContext.current()
-    if "relay.ext.tensorrt.options" in pass_ctx.config:
-        return 
tuple(pass_ctx.config["relay.ext.tensorrt.options"].tensorrt_version)  # type: 
ignore
-    return tuple(tvm.get_global_func("relay.op.get_tensorrt_version")())  # 
type: ignore
+    # cf logic in tensorrt/codegen.cc::SaveGlobalAttributes
+    # First check for version in target.
+    target = get_tensorrt_target()
+    version = target.attrs["tensorrt_version"]
+    if len(version) == 3:
+        return int(version[0]), int(version[1]), int(version[2])
+    assert len(version) == 0
+
+    # Next, ask runtime for its version.
+    if is_tensorrt_runtime_enabled():
+        get_version = tvm.get_global_func("relay.ext.tensorrt.get_version")
+        version = get_version()
+        assert len(version) == 3
+        return int(version[0]), int(version[1]), int(version[2])
+
+    # Finally, use default.
+    logger.warning(
+        "TVM was not built against TensorRT and no version was provided in the 
'tensorrt' target."
+        "Defaulting to 6.0.1."
+    )
+    return (6, 0, 1)
 
 
 def get_tensorrt_use_implicit_batch_mode() -> bool:
-    pass_ctx = tvm.transform.PassContext.current()
-    if "relay.ext.tensorrt.options" in pass_ctx.config:
-        return pass_ctx.config["relay.ext.tensorrt.options"].use_implicit_batch
-    logger.warning(
-        "PassContext has no relay.ext.tensorrt.options config, using default 
value "
-        "use_implicit_batch=True."
-    )
-    return True
+    """Returns the "use_implicit_batch" attribute of the current "tensorrt" 
target."""
+    target = get_tensorrt_target()
+    return target.attrs["use_implicit_batch"]
 
 
 def get_tensorrt_remove_no_mac_subgraphs() -> bool:
-    pass_ctx = tvm.transform.PassContext.current()
-    if "relay.ext.tensorrt.options" in pass_ctx.config:
-        return 
pass_ctx.config["relay.ext.tensorrt.options"].remove_no_mac_subgraphs
-    logger.warning(
-        "PassContext has no relay.ext.tensorrt.options config, using default 
value "
-        "remove_no_mac_subgraphs=False."
-    )
-    return False
+    """Returns the "remove_no_mac_subgraphs" attribute of the current 
"tensorrt" target."""
+    target = get_tensorrt_target()
+    return target.attrs["remove_no_mac_subgraphs"]
+
+
+def get_tensorrt_use_fp16() -> bool:
+    """Returns the "use_fp16" attribute of the current "tensorrt" target."""
+    target = get_tensorrt_target()
+    return target.attrs["use_fp16"]
 
 
 def partition_for_tensorrt(
     mod: tvm.IRModule,
     params: Optional[Dict[str, tvm.nd.NDArray]] = None,
-    version: Optional[Tuple[int, int, int]] = None,
-    use_implicit_batch: bool = True,
-    remove_no_mac_subgraphs: bool = False,
-    max_workspace_size: int = 1 << 30,
-    use_fp16: bool = False,
-    use_uint8: bool = False,
-) -> Tuple[tvm.IRModule, Dict[str, Any]]:
-    """Partition the graph greedily offloading supported operators to TensorRT.
+    # CAUTION: Can't use default Target("tensorrt") here since the target kind 
is only available
+    #          if is_tensorrt_compiler_enabled() == True.
+    target: Optional[tvm.target.Target] = None,
+) -> tvm.IRModule:
+    """Partition all functions in mod to greedily offload supported operators 
to TensorRT.
 
     Parameters
     ----------
     mod : tvm.IRModule
-        The module to run passes on.
+        The module to partition.
+    target : tvm.target.Target
+        A target of kind "tensorrt" describing additional partitioning and 
compilation options.
     params : Optional[Dict[str, tvm.nd.NDArray]]
         Constant input parameters.
-    version : Optional[Tuple[int, int, int]]
-        TensorRT version to target as tuple of (major, minor, patch). If TVM 
is compiled with
-        USE_TENSORRT_RUNTIME=ON, the linked TensorRT version will be used 
instead.
-    use_implicit_batch : bool
-        Use TensorRT implicit batch mode (default true). Setting to false will 
enable explicit batch
-        mode which will widen supported operators to include those which 
modify the batch dimension,
-        but may reduce performance for some models.
-    remove_no_mac_subgraphs : bool
-        Removes subgraphs which have been partitioned for TensorRT if they do 
not have any
-        multiply-accumulate operations. The removed subgraphs will go through 
TVM's standard
-        compilation instead. Can improve performance.
-    max_workspace_size : int
-        How many bytes of workspace size to allow each subgraph to use for 
TensorRT engine creation.
-        See TensorRT documentation for more info.
-    use_fp16: bool
-        Allows, TRT to automatically convert FP32 inputs to FP16. Also, it is 
required to be enabled
-        if FP16 inputs tensors and weights are used.
-        Note that TensorRT will still choose a higher-precision kernel if it 
results in overall
-        lower runtime, or if no low-precision implementation exists.
-    use_uint8: bool
-        Allows, TRT to automatically convert FP32 inputs to UINT8.
 
     Returns
     -------
-    mod_and_config : Tuple[tvm.IRModule, Dict[str, Any]]
-        A tuple of 1) annotated and partitioned module and 2) 
"relay.ext.tensorrt.options"
-        configuration which should be given to PassContext when building.
+    partitioned_mod : tvm.IRModule
+        The partitioned module.
 
     """
-    config: Dict[str, Any] = {
-        "use_implicit_batch": use_implicit_batch,
-        "max_workspace_size": max_workspace_size,
-        "remove_no_mac_subgraphs": remove_no_mac_subgraphs,
-        "use_fp16": use_fp16,
-        "use_uint8": use_uint8,
-    }
-    if version:
-        assert isinstance(version, tuple) and len(version) == 3
-        config["tensorrt_version"] = version
-    else:
-        linked_version = 
tuple(tvm.get_global_func("relay.op.get_tensorrt_version")())
-        if not linked_version:
-            logger.warning(
-                "TVM was not built against TensorRT and no version was 
provided to "
-                "partition_for_tensorrt. Defaulting to 6.0.1"
-            )
-            linked_version = (6, 0, 1)
-        config["tensorrt_version"] = linked_version
-
+    assert is_tensorrt_compiler_enabled(), "Can only partition for TensorRT if 
it is enabled"
     if params:
         mod["main"] = bind_params_by_name(mod["main"], params)
+    if target is None:
+        # Use a default target. The get_tensorrt_target() function will 
similarly create an
+        # equivalent default target when compilation continues after 
partitioning.
+        target = tvm.target.Target("tensorrt")
 
     seq = tvm.transform.Sequential(
         [
@@ -174,24 +165,27 @@ def partition_for_tensorrt(
             transform.InferType(),
         ]
     )
-    with tvm.transform.PassContext(opt_level=3, 
config={"relay.ext.tensorrt.options": config}):
+    with target:
         mod = seq(mod)
-        # TODO(mbs): Revisit
-        # mod = prune_tensorrt_subgraphs(mod)
-    return mod, config
+        mod = prune_tensorrt_subgraphs(mod)
+    return mod
 
 
 def is_supported_trt_type(typ: Union[tvm.ir.TensorType, tvm.ir.TupleType], 
op_name: str) -> bool:
     """Check whether a type is supported by TensorRT."""
-    supported_dtypes = ["float32", "float16"]
+    supported_dtypes = ["float32"]
+    if get_tensorrt_use_fp16():
+        supported_dtypes.append("float16")
     if isinstance(typ, tvm.ir.TensorType):
         if typ.dtype not in supported_dtypes:
-            logger.info(f"{op_name}: Only float32 and float16 tensor dtypes 
are supported.")
+            logger.info(f"{op_name}: Only {supported_dtypes} tensor dtypes are 
supported.")
             return False
-        # assumes dim 0 is for batch and can be dynamic
-        # TODO(mbs): But does this depend use_implicit_batch flag?
-        for dim_shape in typ.shape[1:]:
-            if isinstance(dim_shape, tvm.tir.expr.Any):
+        dims = typ.shape
+        if get_tensorrt_use_implicit_batch_mode():
+            # The first dimension can be Any.
+            dims = dims[1:]
+        for dim in dims:
+            if isinstance(dim, tvm.tir.expr.Any):
                 logger.info(f"{op_name}: Only statically known tensor shapes 
are supported.")
                 return False
     elif isinstance(typ, tvm.ir.TupleType):
@@ -241,13 +235,19 @@ CheckFunc = Callable[[Any, List[relay.expr.Expr], str], 
bool]
 
 
 def make_predicate(checker: CheckFunc) -> Callable[[relay.expr.Expr], bool]:
+    """Returns the pattern predicate which performs the standard checks, then 
invokes the
+    more primitive checker."""
+
     def predicate(expr: relay.expr.Expr) -> bool:
         op_name = get_op_name(expr)
         attrs = get_attrs(expr)
         args = get_args(expr)
         if not all([is_supported_trt_type(arg.checked_type, op_name) for arg 
in args]):
             return False
-        return checker(attrs, args, op_name)
+        if not checker(attrs, args, op_name):
+            return False
+        logger.info(f"{op_name}: Predicate passes")
+        return True
 
     return predicate
 
@@ -535,11 +535,16 @@ def concatenate_checker(
         if int(attrs.axis) == 0:
             logger.info(f"{op_name}: can't modify batch dimension.")
             return False
-        if isinstance(args[0], relay.Tuple):
-            for tuple_input in args[0].fields:
-                if isinstance(tuple_input, Constant):
-                    logger.info(f"{op_name}: can't concatenate tensors with 
constants.")
-                    return False
+
+    if not isinstance(args[0], relay.Tuple):
+        logger.info("f{op_name}: concatenate must be applied to a literal 
tuple")
+        return False
+
+    for tuple_input in args[0].fields:
+        if isinstance(tuple_input, Constant):
+            logger.info(f"{op_name}: can't concatenate tensors with 
constants.")
+            return False
+
     return True
 
 
diff --git a/src/relay/backend/contrib/codegen_c/codegen.cc 
b/src/relay/backend/contrib/codegen_c/codegen.cc
index ee8724fe92..41f0a0a064 100644
--- a/src/relay/backend/contrib/codegen_c/codegen.cc
+++ b/src/relay/backend/contrib/codegen_c/codegen.cc
@@ -360,8 +360,8 @@ class CodegenCModule {
 };
 
 /*! \brief The actual translation pass. */
-transform::Pass CCompilerImpl() {
-  auto pass_func = [=](IRModule mod, const transform::PassContext& pass_ctx) {
+tvm::transform::Pass CCompilerImpl() {
+  auto pass_func = [=](IRModule mod, const tvm::transform::PassContext& 
pass_ctx) {
     VLOG(1) << "CCompilerImpl input:" << std::endl << PrettyPrint(mod);
     Target target = GetCCompilerTarget();
 
@@ -388,10 +388,10 @@ transform::Pass CCompilerImpl() {
   return tvm::transform::CreateModulePass(pass_func, 0, "CCompilerImpl", {});
 }
 
-transform::Pass CCompilerPass() {
+tvm::transform::Pass CCompilerPass() {
   return transform::Sequential(
-      
{transforms::OutlineCompilerFunctionsWithExistingGlobalSymbols("ccompiler"), 
CCompilerImpl(),
-       transforms::MarkCompilerFunctionsAsExtern("ccompiler")});
+      
{transform::OutlineCompilerFunctionsWithExistingGlobalSymbols("ccompiler"), 
CCompilerImpl(),
+       transform::MarkCompilerFunctionsAsExtern("ccompiler")});
 }
 
 }  // namespace contrib
diff --git a/src/relay/backend/contrib/cutlass/codegen.cc 
b/src/relay/backend/contrib/cutlass/codegen.cc
index de2934173b..2e76ab1cbb 100644
--- a/src/relay/backend/contrib/cutlass/codegen.cc
+++ b/src/relay/backend/contrib/cutlass/codegen.cc
@@ -902,8 +902,8 @@ class CutlassModuleCodegen {
  * \brief A small shim to redirect to the 
'relay.ext.cutlass.compile_for_cutlass' Python
  * function which does the main CUTLASS training, c-code generation and 
compilation steps.
  */
-transform::Pass CompileForCutlassImpl() {
-  auto pass_func = [=](IRModule mod, const transform::PassContext& pass_ctx) {
+tvm::transform::Pass CompileForCutlassImpl() {
+  auto pass_func = [=](IRModule mod, const tvm::transform::PassContext& 
pass_ctx) {
     VLOG(1) << "CompileForCutlass input:" << std::endl << PrettyPrint(mod);
     const auto* pf = 
runtime::Registry::Get("relay.ext.cutlass.compile_for_cutlass");
     ICHECK(pf != nullptr) << "Cannot find compile_for_cutlass function";
@@ -926,10 +926,10 @@ runtime::Module CreateCSourceModule(const IRModule& mod) {
 
 
TVM_REGISTER_GLOBAL("relay.ext.cutlass.create_c_source_module").set_body_typed(CreateCSourceModule);
 
-transform::Pass CompileForCutlass() {
+tvm::transform::Pass CompileForCutlass() {
   return transform::Sequential(
-      
{transforms::OutlineCompilerFunctionsWithExistingGlobalSymbols("cutlass"),
-       CompileForCutlassImpl(), 
transforms::MarkCompilerFunctionsAsExtern("cutlass")});
+      {transform::OutlineCompilerFunctionsWithExistingGlobalSymbols("cutlass"),
+       CompileForCutlassImpl(), 
transform::MarkCompilerFunctionsAsExtern("cutlass")});
 }
 
 }  // namespace cutlass
diff --git a/src/relay/backend/contrib/tensorrt/codegen.cc 
b/src/relay/backend/contrib/tensorrt/codegen.cc
index e08cd240d4..1c4a8d7806 100644
--- a/src/relay/backend/contrib/tensorrt/codegen.cc
+++ b/src/relay/backend/contrib/tensorrt/codegen.cc
@@ -29,6 +29,7 @@
 #include <string>
 #include <vector>
 
+#include "../../../transforms/compiler_function_utils.h"
 #include "../../utils.h"
 #include "../codegen_json/codegen_json.h"
 
@@ -39,36 +40,49 @@
 namespace tvm {
 namespace relay {
 namespace contrib {
+namespace tensorrt {
 
-/*! \brief Attributes to store the compiler options for TensorRT. */
-struct TensorRTCompilerConfigNode : public 
tvm::AttrsNode<TensorRTCompilerConfigNode> {
-  Array<Integer> tensorrt_version;
-  bool use_implicit_batch;
-  size_t max_workspace_size;
-  bool remove_no_mac_subgraphs;
-  bool use_fp16;
-  bool use_uint8;
-
-  TVM_DECLARE_ATTRS(TensorRTCompilerConfigNode, 
"ext.attrs.TensorRTCompilerConfigNode") {
-    TVM_ATTR_FIELD(tensorrt_version)
-        .describe("TensorRT version as (major, minor, patch).")
-        .set_default(Array<Integer>({6, 0, 1}));
-    TVM_ATTR_FIELD(use_implicit_batch).set_default(true);
-    TVM_ATTR_FIELD(max_workspace_size).set_default(size_t(1) << 30);
-    TVM_ATTR_FIELD(remove_no_mac_subgraphs).set_default(false);
-    TVM_ATTR_FIELD(use_fp16).set_default(false);
-    TVM_ATTR_FIELD(use_uint8).set_default(false);
-  }
-};
+/*!
+ * \brief Check whether TensorRT graph executor is enabled.
+ * \return True if enabled, False if not.
+ */
+inline constexpr bool IsRuntimeEnabled() {
+#if TVM_GRAPH_EXECUTOR_TENSORRT
+  return true;
+#else
+  return false;
+#endif  // TVM_GRAPH_EXECUTOR_TENSORRT
+}
 
-class TensorRTCompilerConfig : public Attrs {
- public:
-  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TensorRTCompilerConfig, Attrs,
-                                            TensorRTCompilerConfigNode);
-};
+TVM_REGISTER_GLOBAL("relay.ext.tensorrt.is_runtime_enabled").set_body_typed(IsRuntimeEnabled);
 
-TVM_REGISTER_NODE_TYPE(TensorRTCompilerConfigNode);
-TVM_REGISTER_PASS_CONFIG_OPTION("relay.ext.tensorrt.options", 
TensorRTCompilerConfig);
+/*!
+ * \brief Get TensorRT version that TVM is built against.
+ * \return Array of three integers for major, minor, and patch, or empty array 
if TensorRT graph
+ * runtime is not enabled.
+ */
+Array<Integer> GetVersion() {
+#if TVM_GRAPH_EXECUTOR_TENSORRT
+  return {Integer(NV_TENSORRT_MAJOR), Integer(NV_TENSORRT_MINOR), 
Integer(NV_TENSORRT_PATCH)};
+#else
+  return {};
+#endif  // TVM_GRAPH_EXECUTOR_TENSORRT
+}
+
+TVM_REGISTER_GLOBAL("relay.ext.tensorrt.get_version").set_body_typed(GetVersion);
+
+/*!
+ * \brief Returns the "tensorrt" Target instance to use for compilation.
+ */
+Target GetTensorRTTarget() {
+  Target target = Target::Current(/*allow_not_defined=*/true);
+  if (!target.defined() || target->kind->name != "tensorrt") {
+    // Since we allow partition_for_tensorrt to use the default "tensorrt" 
target, we should
+    // similarly allow the custom pass to execute without a specific 
"tensorrt" target in scope.
+    target = Target("tensorrt");
+  }
+  return target;
+}
 
 using JSONGraphNode = tvm::runtime::json::JSONGraphNode;
 using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry;
@@ -87,6 +101,7 @@ class CollectFromCompositeFunctionBody : public ExprVisitor {
   explicit CollectFromCompositeFunctionBody(TensorRTJSONSerializer* serializer)
       : serializer_(serializer), node_(std::make_shared<JSONGraphNode>()) {}
 
+  // We'll need to implement these out-of-band since they use the serializer.
   void VisitExpr_(const ConstantNode* constant_node) final;
   void VisitExpr_(const CallNode* call_node) final;
 
@@ -190,6 +205,7 @@ class CollectFromCompositeFunctionBody : public ExprVisitor 
{
     extractor.Extract(const_cast<Object*>(attr_obj));
   }
 
+  /*! \brief The parent serializer for the overall TensorRT partition. */
   TensorRTJSONSerializer* serializer_;
   /*! \brief Accumulated translated arguments. */
   std::vector<JSONGraphNodeEntry> args_;
@@ -207,9 +223,10 @@ class CollectFromCompositeFunctionBody : public 
ExprVisitor {
  */
 class TensorRTJSONSerializer : public JSONSerializer {
  public:
-  TensorRTJSONSerializer(const std::string& symbol, const Expr& expr)
-      : JSONSerializer(symbol, expr) {}
+  TensorRTJSONSerializer(Target target, const std::string& symbol, const Expr& 
expr)
+      : JSONSerializer(symbol, expr), target_(std::move(target)) {}
 
+ private:
   using JSONSerializer::VisitExpr_;
 
   std::vector<JSONGraphNodeEntry> VisitExpr_(const CallNode* call_node) final {
@@ -245,40 +262,62 @@ class TensorRTJSONSerializer : public JSONSerializer {
     node->CaptureAttrs(*collector.node_);
 
     // Capture global settings on the JSON node.
-    SaveGlobalAttributes(node);
+    // TODO(mbs): Why on every call?
+    SaveGlobalAttributes(node.get());
 
     VLOG(1) << name << " has " << node->GetInputs().size() << " inputs";
 
     return AddNode(node, GetRef<Expr>(call_node));
   }
 
-  static void SaveGlobalAttributes(std::shared_ptr<JSONGraphNode> node) {
-    auto ctx = transform::PassContext::Current();
-    auto cfg = 
ctx->GetConfig<TensorRTCompilerConfig>("relay.ext.tensorrt.options");
-    if (!cfg.defined()) {
-      cfg = AttrsWithDefaultValues<TensorRTCompilerConfig>();
+  static void SetAttr(JSONGraphNode* node, const std::string& key,
+                      std::vector<std::string> values) {
+    node->SetAttr(key, std::vector<dmlc::any>({std::move(values)}));
+  }
+
+  /*! \brief Capture the compilation options as attributes on \p node. */
+  void SaveGlobalAttributes(JSONGraphNode* node) {
+    {
+      // cf logic in tensorrt.py::get_tensorrt_version.
+      // First check for version in target.
+      Array<Integer> target_attr = 
target_->GetAttr<Array<Integer>>("tensorrt_version").value();
+      if (target_attr.empty()) {
+        // Next, ask runtime for its version.
+        target_attr = GetVersion();
+      }
+      if (target_attr.empty()) {
+        // Finally, use default.
+        target_attr = {6, 0, 1};
+      }
+      ICHECK_EQ(target_attr.size(), 3);
+      SetAttr(node, "tensorrt_version",
+              {std::to_string(target_attr[0]), std::to_string(target_attr[1]),
+               std::to_string(target_attr[2])});
+    }
+
+    {
+      Bool target_attr = target_->GetAttr<Bool>("use_implicit_batch").value();
+      SetAttr(node, "use_implicit_batch", 
{std::to_string(target_attr->value)});
+    }
+
+    {
+      Integer target_attr = 
target_->GetAttr<Integer>("max_workspace_size").value();
+      SetAttr(node, "max_workspace_size", 
{std::to_string(target_attr->value)});
+    }
+
+    {
+      Bool target_attr = target_->GetAttr<Bool>("use_fp16").value();
+      SetAttr(node, "use_fp16", {std::to_string(target_attr->value)});
+    }
+
+    {
+      Bool target_attr = target_->GetAttr<Bool>("use_uint8").value();
+      SetAttr(node, "use_uint8", {std::to_string(target_attr->value)});
     }
-    ICHECK_EQ(cfg.value()->tensorrt_version.size(), 3);
-    std::vector<std::string> tensorrt_version = 
{std::to_string(cfg.value()->tensorrt_version[0]),
-                                                 
std::to_string(cfg.value()->tensorrt_version[1]),
-                                                 
std::to_string(cfg.value()->tensorrt_version[2])};
-    std::vector<std::string> use_implicit_batch = 
{std::to_string(cfg.value()->use_implicit_batch)};
-    std::vector<std::string> max_workspace_size = 
{std::to_string(cfg.value()->max_workspace_size)};
-    std::vector<std::string> use_fp16 = 
{std::to_string(cfg.value()->use_fp16)};
-    std::vector<std::string> use_uint8 = 
{std::to_string(cfg.value()->use_uint8)};
-    std::vector<dmlc::any> tensorrt_version_attr, use_implicit_batch_attr, 
max_workspace_size_attr,
-        use_fp16_attr, use_uint8_attr;
-    tensorrt_version_attr.emplace_back(tensorrt_version);
-    use_implicit_batch_attr.emplace_back(use_implicit_batch);
-    max_workspace_size_attr.emplace_back(max_workspace_size);
-    use_fp16_attr.emplace_back(use_fp16);
-    use_uint8_attr.emplace_back(use_uint8);
-    node->SetAttr("tensorrt_version", tensorrt_version_attr);
-    node->SetAttr("use_implicit_batch", use_implicit_batch_attr);
-    node->SetAttr("max_workspace_size", max_workspace_size_attr);
-    node->SetAttr("use_fp16", use_fp16_attr);
-    node->SetAttr("use_uint8", use_uint8_attr);
   }
+
+  /*! \brief The "tensorrt" Target guiding compilation. */
+  Target target_;
 };
 
 void CollectFromCompositeFunctionBody::VisitExpr_(const ConstantNode* 
constant_node) {
@@ -304,64 +343,74 @@ void CollectFromCompositeFunctionBody::VisitExpr_(const 
CallNode* call_node) {
 }
 
 /*!
- * \brief Create a runtime module for TensorRT.
- * \param ref The ext_func Relay expression/module to be executed using extern 
ops.
- * \return A runtime module.
- */
-runtime::Module TensorRTCompiler(const ObjectRef& ref) {
-  ICHECK(ref->IsInstance<FunctionNode>()) << "The input ref is expected to be 
a Relay function.";
-  Function func = Downcast<Function>(ref);
-  std::string func_name = backend::GetExtSymbol(func);
-
-  VLOG(1) << "TensorRT partition:" << std::endl << PrettyPrint(func);
-  TensorRTJSONSerializer serializer(func_name, func);
-  serializer.serialize();
-  std::string graph_json = serializer.GetJSON();
-  VLOG(1) << "TensorRT JSON:" << std::endl << graph_json;
-
-  // Note that serializer.const_name_to_constant() is ignored. Instead the 
TECompiler invokes
-  // a callback which calls backend::UpdateConstants to capture the map before 
the function
-  // 'disappears' into lowered form, on the assumption the visit order and 
thus constant
-  // names match those generated by the JSONSerializer.
-
-  const auto* pf = runtime::Registry::Get("runtime.tensorrt_runtime_create");
-  ICHECK(pf != nullptr) << "Cannot find TensorRT runtime module create 
function.";
-  VLOG(1) << "Creating tensorrt runtime::Module for '" << func_name << "'";
-  runtime::Module lib = (*pf)(func_name, graph_json, serializer.const_names());
-  return lib;
-}
-
-TVM_REGISTER_GLOBAL("relay.ext.tensorrt").set_body_typed(TensorRTCompiler);
-
-/*!
- * \brief Check whether TensorRT graph executor is enabled.
- * \return True if enabled, False if not.
+ * \brief The main TensorRT compiler.
+ *
+ * TODO(mbs): Currently we create a \p TensorRTRuntimeModule for every 
function with
+ * Compiler="tensorrt" (ie for each partition). Since the TensorRT engine is 
only designed to
+ * handle a single entry point this is mostly sensible, however there are 
probably opportunities
+ * for more sharing between functions. However, note this means each call to a 
TensorRT-compiled
+ * function will require a linear scan of imported runtime modules to find the 
matching
+ * TensorRTRuntimeModule implementing it.
  */
-inline constexpr bool IsTensorRTRuntimeEnabled() {
-#if TVM_GRAPH_EXECUTOR_TENSORRT
-  return true;
-#else
-  return false;
-#endif  // TVM_GRAPH_EXECUTOR_TENSORRT
+tvm::transform::Pass CompileForTensorRTImpl() {
+  auto pass_func = [](IRModule mod, const tvm::transform::PassContext& 
pass_ctx) {
+    VLOG(1) << "CompileForTensorRT input:" << std::endl << PrettyPrint(mod);
+    Target target = GetTensorRTTarget();
+
+    const auto* pf = runtime::Registry::Get("runtime.tensorrt_runtime_create");
+    ICHECK(pf != nullptr) << "Cannot find TensorRT runtime module create 
function.";
+
+    // The accumulated external runtime modules.
+    Array<runtime::Module> external_mods =
+        
mod->GetAttr<Array<runtime::Module>>(tvm::attr::kExternalMods).value_or({});
+    // The accumulated constant bindings.
+    Map<String, runtime::NDArray> const_name_to_constant =
+        mod->GetAttr<Map<String, 
runtime::NDArray>>(tvm::attr::kConstNameToConstant).value_or({});
+
+    for (const auto& kv : mod->functions) {
+      if (const auto* function_node = kv.second.as<FunctionNode>()) {
+        if (function_node->HasNonzeroAttr(attr::kPrimitive)) {
+          Optional<String> opt_compiler = 
function_node->GetAttr<String>(attr::kCompiler);
+          if (opt_compiler && opt_compiler.value() == "tensorrt") {
+            // Serialize the function to JSON.
+            TensorRTJSONSerializer serializer(target, kv.first->name_hint,
+                                              GetRef<Function>(function_node));
+            serializer.serialize();
+            std::string graph_json = serializer.GetJSON();
+            VLOG(1) << "TensorRT JSON for '" << kv.first->name_hint << "':" << 
std::endl
+                    << graph_json;
+
+            // Remember all the constant bindings.
+            for (const auto& kv2 : serializer.const_name_to_constant()) {
+              ICHECK_EQ(const_name_to_constant.count(kv2.first), 0);
+              VLOG(1) << "binding constant '" << kv2.first << "' for function 
'"
+                      << kv.first->name_hint << "'";
+              const_name_to_constant.Set(kv2.first, kv2.second);
+            }
+
+            // Create the actual runtime module.
+            runtime::Module runtime_mod =
+                (*pf)(kv.first->name_hint, graph_json, 
serializer.const_names());
+
+            // Remember the runtime module.
+            external_mods.push_back(runtime_mod);
+          }
+        }
+      }
+    }
+    return WithAttrs(mod, {{tvm::attr::kExternalMods, external_mods},
+                           {tvm::attr::kConstNameToConstant, 
const_name_to_constant}});
+  };
+  return tvm::transform::CreateModulePass(pass_func, 0, "CompileForTensorRT", 
{});
 }
 
-/*!
- * \brief Get TensorRT version that TVM is built against.
- * \return Array of three integers for major, minor, and patch, or empty array 
if TensorRT graph
- * runtime is not enabled.
- */
-Array<Integer> GetTensorRTVersion() {
-#if TVM_GRAPH_EXECUTOR_TENSORRT
-  return {Integer(NV_TENSORRT_MAJOR), Integer(NV_TENSORRT_MINOR), 
Integer(NV_TENSORRT_PATCH)};
-#else
-  return {};
-#endif  // TVM_GRAPH_EXECUTOR_TENSORRT
+tvm::transform::Pass CompileForTensorRT() {
+  return transform::Sequential(
+      
{transform::OutlineCompilerFunctionsWithExistingGlobalSymbols("tensorrt"),
+       CompileForTensorRTImpl(), 
transform::MarkCompilerFunctionsAsExtern("tensorrt")});
 }
 
-TVM_REGISTER_GLOBAL("relay.op.is_tensorrt_runtime_enabled")
-    .set_body_typed(IsTensorRTRuntimeEnabled);
-TVM_REGISTER_GLOBAL("relay.op.get_tensorrt_version").set_body_typed(GetTensorRTVersion);
-
+}  // namespace tensorrt
 }  // namespace contrib
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/backend/contrib/tensorrt/target.cc 
b/src/relay/backend/contrib/tensorrt/codegen.h
similarity index 57%
copy from src/relay/backend/contrib/tensorrt/target.cc
copy to src/relay/backend/contrib/tensorrt/codegen.h
index 85d127ab71..813a866375 100644
--- a/src/relay/backend/contrib/tensorrt/target.cc
+++ b/src/relay/backend/contrib/tensorrt/codegen.h
@@ -18,25 +18,30 @@
  */
 
 /*!
- * \file src/relay/backend/contrib/tensorrt/target.cc
- * \brief Registers the "tensorrt" external codegen TargetKind.
+ * \file src/relay/backend/contrib/tensorrt/codegen.h
+ * \brief The 'custom' compilation pass for TensorRT (invoked by the 
RelayToTIRTargetHook pass).
  */
 
-#include <tvm/target/target.h>
+#ifndef TVM_RELAY_BACKEND_CONTRIB_TENSORRT_CODEGEN_H_
+#define TVM_RELAY_BACKEND_CONTRIB_TENSORRT_CODEGEN_H_
+
+#include <tvm/ir/transform.h>
 
 namespace tvm {
 namespace relay {
 namespace contrib {
+namespace tensorrt {
 
 /*!
- * \brief This external codegen target can offload compilation to the TensorRT 
compiler.
- *  - Patterns: python/tvm/relay/op/contrib/tensorrt.py
- *  - Custom compiler: src/relay/backend/contrib/tensorrt/codegen.cc
- *  - Runtime: src/runtime/contrib/tensorrt/ *.cc
+ * \brief Returns the pass which replaces all calls to "Primitive" functions 
with a "Compiler"
+ * attribute of "tensorrt" with calls to an extern which is implemented by a 
\p TensorRTRuntime
+ * runtime module added to the IRModule's "external_mods" attribute.
  */
-TVM_REGISTER_TARGET_KIND("tensorrt", kDLCUDA)
-    .set_attr<Bool>(tvm::attr::kIsExternalCodegen, Bool(true));
+transform::Pass CompileForTensorRT();
 
+}  // namespace tensorrt
 }  // namespace contrib
 }  // namespace relay
 }  // namespace tvm
+
+#endif  // TVM_RELAY_BACKEND_CONTRIB_TENSORRT_CODEGEN_H_
diff --git a/src/relay/backend/contrib/tensorrt/target.cc 
b/src/relay/backend/contrib/tensorrt/target.cc
index 85d127ab71..2e4581d30a 100644
--- a/src/relay/backend/contrib/tensorrt/target.cc
+++ b/src/relay/backend/contrib/tensorrt/target.cc
@@ -24,19 +24,46 @@
 
 #include <tvm/target/target.h>
 
+#include "./codegen.h"
+
 namespace tvm {
 namespace relay {
 namespace contrib {
+namespace tensorrt {
 
 /*!
  * \brief This external codegen target can offload compilation to the TensorRT 
compiler.
  *  - Patterns: python/tvm/relay/op/contrib/tensorrt.py
  *  - Custom compiler: src/relay/backend/contrib/tensorrt/codegen.cc
- *  - Runtime: src/runtime/contrib/tensorrt/ *.cc
+ *  - Runtime: src/runtime/contrib/tensorrt/...
  */
 TVM_REGISTER_TARGET_KIND("tensorrt", kDLCUDA)
-    .set_attr<Bool>(tvm::attr::kIsExternalCodegen, Bool(true));
+    .set_attr<Bool>(tvm::attr::kIsExternalCodegen, Bool(true))
+    .set_attr<FTVMRelayToTIR>("RelayToTIR", CompileForTensorRT())
+    // An array of three integers giving the major, minor, and patch numbers for 
the supported
+    // TensorRT compiler version. If empty will be auto-detected from linked 
library. Default empty.
+    .add_attr_option<Array<Integer>>("tensorrt_version", Array<Integer>())
+    // If true, the first tensor dimension for most operators is allowed to be 
Any and
+    // TensorRT will assume it represents a batch dimension only known at 
inference time.
+    // Fewer Relay operators are supported in implicit batch mode. Default 
true.
+    .add_attr_option<Bool>("use_implicit_batch", Bool(true))
+    // If true, excludes sub-graphs which do not have multiply-accumulate 
operations, even though
+    // TensorRT supports them. This is a simple heuristic to optimize the 
partitioning between
+    // TensorRT and TVM. Not required if using Collage for partitioning. 
Default false.
+    .add_attr_option<Bool>("remove_no_mac_subgraphs", Bool(false))
+    // How many bytes of workspace size to allow each subgraph to use for 
TensorRT engine creation.
+    // Default 1G.
+    .add_attr_option<Integer>("max_workspace_size", Integer(1 << 30))
+    // If true, allows TensorRT to automatically convert float32 operations to 
float16. Must also be
+    // enabled if any float16 operations are in the model. Note that TensorRT 
may still choose a
+    // higher-precision kernel if it results in overall lower runtime, or if 
no low-precision
+    // implementation exists. Default false.
+    .add_attr_option<Bool>("use_fp16", Bool(false))
+    // If true, allows TensorRT to automatically convert float32 operations to 
uint8
+    // (aka quantized). Default false.
+    .add_attr_option<Bool>("use_uint8", Bool(false));
 
+}  // namespace tensorrt
 }  // namespace contrib
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/transforms/compiler_function_utils.cc 
b/src/relay/transforms/compiler_function_utils.cc
index 0df9f5ee29..1dafcd10a3 100644
--- a/src/relay/transforms/compiler_function_utils.cc
+++ b/src/relay/transforms/compiler_function_utils.cc
@@ -24,14 +24,13 @@
 
 #include "./compiler_function_utils.h"
 
-#include "../op/call/call.h"
 #include "tvm/relay/analysis.h"
 #include "tvm/relay/expr_functor.h"
 #include "tvm/relay/transform.h"
 
 namespace tvm {
 namespace relay {
-namespace transforms {
+namespace transform {
 namespace {
 
 /*!
@@ -211,8 +210,8 @@ GlobalVar ExistingGlobalSymbolCache::GetGlobalSymbol(const 
Function& function) {
   return global_var;
 }
 
-transform::Pass OutlineCompilerFunctions(std::shared_ptr<GlobalSymbolCache> 
cache,
-                                         std::string compiler_filter) {
+tvm::transform::Pass 
OutlineCompilerFunctions(std::shared_ptr<GlobalSymbolCache> cache,
+                                              std::string compiler_filter) {
   runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> 
pass_func =
       [cache = std::move(cache), compiler_filter = std::move(compiler_filter)](
           IRModule mod, transform::PassContext ctx) {
@@ -235,12 +234,13 @@ transform::Pass 
OutlineCompilerFunctions(std::shared_ptr<GlobalSymbolCache> cach
 }
 
 // Any Java programmers in the house?
-transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols(std::string 
compiler_filter) {
+tvm::transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols(
+    std::string compiler_filter) {
   return 
OutlineCompilerFunctions(std::make_shared<ExistingGlobalSymbolCache>(),
                                   std::move(compiler_filter));
 }
 
-transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter) {
+tvm::transform::Pass MarkCompilerFunctionsAsExtern(std::string 
compiler_filter) {
   runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> 
pass_func =
       [compiler_filter = std::move(compiler_filter)](IRModule mod, 
transform::PassContext ctx) {
         VLOG(1) << "MarkCompilerFunctionsAsExtern input:" << std::endl << 
PrettyPrint(mod);
@@ -262,7 +262,7 @@ transform::Pass MarkCompilerFunctionsAsExtern(std::string 
compiler_filter) {
   return tvm::transform::CreateModulePass(pass_func, 0, 
"MarkCompilerFunctionsAsExtern", {});
 }
 
-transform::Pass InlineCompilerFunctionsBoundTo(Array<GlobalVar> global_vars) {
+tvm::transform::Pass InlineCompilerFunctionsBoundTo(Array<GlobalVar> 
global_vars) {
   runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> 
pass_func =
       [global_vars = std::move(global_vars)](IRModule mod, 
transform::PassContext ctx) {
         VLOG(1) << "InlineCompilerFunctionsBoundTo with global_vars: " << 
PrettyPrint(global_vars);
@@ -295,6 +295,6 @@ 
TVM_REGISTER_GLOBAL("relay._transform.MarkCompilerFunctionsAsExtern")
 TVM_REGISTER_GLOBAL("relay._transform.InlineCompilerFunctionsBoundTo")
     .set_body_typed(InlineCompilerFunctionsBoundTo);
 
-}  // namespace transforms
+}  // namespace transform
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/transforms/compiler_function_utils.h 
b/src/relay/transforms/compiler_function_utils.h
index aa98430318..f3499faec2 100644
--- a/src/relay/transforms/compiler_function_utils.h
+++ b/src/relay/transforms/compiler_function_utils.h
@@ -66,7 +66,7 @@
 
 namespace tvm {
 namespace relay {
-namespace transforms {
+namespace transform {
 
 /*!
  * \brief Abstract class representing a cache of unique global vars keyed by 
functions. This can
@@ -105,8 +105,8 @@ class ExistingGlobalSymbolCache : public GlobalSymbolCache {
  * If \p compiler_filter is non-empty only functions with that as their 
attribute value are
  * outlined.
  */
-transform::Pass OutlineCompilerFunctions(std::shared_ptr<GlobalSymbolCache> 
cache,
-                                         std::string compiler_filter = "");
+tvm::transform::Pass 
OutlineCompilerFunctions(std::shared_ptr<GlobalSymbolCache> cache,
+                                              std::string compiler_filter = 
"");
 
 /*!
  * \brief A pass to outline all let-bound and literal functions in direct call 
positions which have
@@ -119,7 +119,8 @@ transform::Pass 
OutlineCompilerFunctions(std::shared_ptr<GlobalSymbolCache> cach
  * This pass may be useful for external codegen using the "RelayToTIR" custom 
pass mechanism
  * to prepare the IRModule before custom lowering.
  */
-transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols(std::string 
compiler_filter = "");
+tvm::transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols(
+    std::string compiler_filter = "");
 
 /*!
  * \brief A pass to mark all global functions which have a "Compiler" 
attribute matching
@@ -132,7 +133,7 @@ transform::Pass 
OutlineCompilerFunctionsWithExistingGlobalSymbols(std::string co
  * This pass may be useful for external codegen using the "RelayToTIR" custom 
pass mechanism to
  * cleanup the IRModule after custom lowering.
  */
-transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter = 
"");
+tvm::transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter 
= "");
 
 /*!
  * \brief A pass to inline all global "Compiler" functions which are bound to 
a global var
@@ -142,9 +143,9 @@ transform::Pass MarkCompilerFunctionsAsExtern(std::string 
compiler_filter = "");
  * This pass may be useful for external codegen which needs to undo 
partitioning based on
  * properties of the entire partition.
  */
-transform::Pass InlineCompilerFunctionsBoundTo(Array<GlobalVar> global_vars);
+tvm::transform::Pass InlineCompilerFunctionsBoundTo(Array<GlobalVar> 
global_vars);
 
-}  // namespace transforms
+}  // namespace transform
 }  // namespace relay
 }  // namespace tvm
 
diff --git a/src/runtime/const_loader_module.cc 
b/src/runtime/const_loader_module.cc
index 2e91d26d5f..a8028e616c 100644
--- a/src/runtime/const_loader_module.cc
+++ b/src/runtime/const_loader_module.cc
@@ -51,15 +51,24 @@ class ConstLoaderModuleNode : public ModuleNode {
       const std::unordered_map<std::string, NDArray>& const_var_ndarray,
       const std::unordered_map<std::string, std::vector<std::string>>& 
const_vars_by_symbol)
       : const_var_ndarray_(const_var_ndarray), 
const_vars_by_symbol_(const_vars_by_symbol) {
+    VLOG(1) << "Creating ConstLoaderModule";
     // Only the related submodules are cached to reduce the number of runtime
     // symbol lookup for initialization. Otherwise, symbols/primitives in the
     // DSO module will also be cached but they never need to be initialized.
-    for (const auto& it : const_vars_by_symbol_) {
-      initialized_[it.first] = false;
+    for (const auto& kv : const_vars_by_symbol_) {
+      for (const auto& var : kv.second) {
+        VLOG(1) << "ConstLoaderModuleNode has constant '" << var << "' for 
function '" << kv.first
+                << "'";
+        ICHECK_GT(const_var_ndarray_.count(var), 0)
+            << "ConstLoaderModuleNode is missing entry for constant '" << var 
<< "' for function '"
+            << kv.first << "'";
+      }
+      initialized_[kv.first] = false;
     }
   }
 
   PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& 
sptr_to_self) final {
+    VLOG(1) << "ConstLoaderModuleNode::GetFunction(" << name << ")";
     // Initialize and memoize the module.
     // Usually, we have some warmup runs. The module initialization should be
     // done at this stage. Therefore, runtime overhead is not a concern.
@@ -88,11 +97,13 @@ class ConstLoaderModuleNode : public ModuleNode {
    */
   Array<NDArray> GetRequiredConstants(const std::string& symbol) {
     Array<NDArray> ret;
-    ICHECK_GT(const_vars_by_symbol_.count(symbol), 0U) << "No symbol is 
recorded for " << symbol;
+    ICHECK_GT(const_vars_by_symbol_.count(symbol), 0U)
+        << "No constants known for function '" << symbol << "'";
     std::vector<std::string> vars = const_vars_by_symbol_[symbol];
-    for (const auto& it : vars) {
-      ICHECK_GT(const_var_ndarray_.count(it), 0U) << "Found not recorded 
constant variable: " << it;
-      ret.push_back(const_var_ndarray_[it]);
+    for (const auto& var : vars) {
+      ICHECK_GT(const_var_ndarray_.count(var), 0U)
+          << "No such constant variable '" << var << "' for function '" << 
symbol << "'";
+      ret.push_back(const_var_ndarray_[var]);
     }
     return ret;
   }
@@ -229,5 +240,6 @@ TVM_REGISTER_GLOBAL("runtime.module.loadbinary_metadata")
     .set_body_typed(ConstLoaderModuleNode::LoadFromBinary);
 TVM_REGISTER_GLOBAL("runtime.module.loadbinary_const_loader")
     .set_body_typed(ConstLoaderModuleNode::LoadFromBinary);
+
 }  // namespace runtime
 }  // namespace tvm
diff --git a/src/runtime/contrib/json/json_runtime.h 
b/src/runtime/contrib/json/json_runtime.h
index 355390765d..3a02202b87 100644
--- a/src/runtime/contrib/json/json_runtime.h
+++ b/src/runtime/contrib/json/json_runtime.h
@@ -54,6 +54,8 @@ class JSONRuntimeBase : public ModuleNode {
     LoadGraph(graph_json_);
   }
 
+  ~JSONRuntimeBase() override = default;
+
   const char* type_key() const override { return "json"; }  // May be 
overridden
 
   /*! \brief Initialize a specific json runtime. */
diff --git a/src/runtime/contrib/tensorrt/tensorrt_builder.cc 
b/src/runtime/contrib/tensorrt/tensorrt_builder.cc
index 5f923667d0..436a6db4c8 100644
--- a/src/runtime/contrib/tensorrt/tensorrt_builder.cc
+++ b/src/runtime/contrib/tensorrt/tensorrt_builder.cc
@@ -45,10 +45,11 @@ TensorRTBuilder::TensorRTBuilder(TensorRTLogger* logger,
       max_workspace_size_(max_workspace_size),
       use_implicit_batch_(use_implicit_batch),
       use_fp16_(use_fp16),
-      batch_size_(batch_size) {
+      use_int8_(false),
+      batch_size_(batch_size),
+      calibrator_(calibrator) {
   // Create TRT builder and network.
   builder_ = nvinfer1::createInferBuilder(*logger);
-  use_int8_ = false;
 
 #if TRT_VERSION_GE(6, 0, 1)
   // Use INetworkV2.
@@ -58,8 +59,7 @@ TensorRTBuilder::TensorRTBuilder(TensorRTLogger* logger,
     flags = 0U;
     builder_->setMaxBatchSize(batch_size_);
   }
-  this->calibrator_ = calibrator;
-  if (calibrator != nullptr) {
+  if (calibrator_ != nullptr) {
     use_int8_ = true;
   }
   network_ = builder_->createNetworkV2(flags);
@@ -177,6 +177,7 @@ TensorRTEngineAndContext TensorRTBuilder::BuildEngine() {
 
   if (use_int8_) {
     config_->setFlag(nvinfer1::BuilderFlag::kINT8);
+    ICHECK(calibrator_);
     config_->setInt8Calibrator(calibrator_);
     LOG(INFO) << "config finishes setting up calibrator as INT8 mode ... ";
   }
@@ -210,6 +211,9 @@ TensorRTEngineAndContext TensorRTBuilder::BuildEngine() {
   nvinfer1::IExecutionContext* context = engine->createExecutionContext();
   CleanUp();
 
+  ICHECK(engine);
+  ICHECK(context);
+
   return {engine, context, network_input_names_, network_output_names_};
 }
 
@@ -254,18 +258,33 @@ nvinfer1::ITensor* 
TensorRTBuilder::GetInputAsTensor(const TensorRTOpInput& inpu
 }
 
 void TensorRTBuilder::CleanUp() {
+  VLOG(1) << "Destroying TensorRT network";
+  ICHECK(network_);
   network_->destroy();
+  network_ = nullptr;
+
 #if TRT_VERSION_GE(6, 0, 1)
+  VLOG(1) << "Destroying TensorRT config";
+  ICHECK(config_);
   config_->destroy();
+  config_ = nullptr;
 #endif
+
+  VLOG(1) << "Destroying TensorRT builder";
+  ICHECK(builder_);
   builder_->destroy();
+  builder_ = nullptr;
+
+  VLOG(1) << "Destroying TensorRT weights";
   for (auto weight : trt_weights_) {
+    ICHECK(weight.values);
     if (weight.type == nvinfer1::DataType::kFLOAT || weight.type == 
nvinfer1::DataType::kHALF) {
       delete[] static_cast<const float*>(weight.values);
     } else {
       delete[] static_cast<const uint16_t*>(weight.values);
     }
   }
+  trt_weights_.clear();
 }
 
 }  // namespace contrib
diff --git a/src/runtime/contrib/tensorrt/tensorrt_builder.h 
b/src/runtime/contrib/tensorrt/tensorrt_builder.h
index 13a118340e..9bccc1ea48 100644
--- a/src/runtime/contrib/tensorrt/tensorrt_builder.h
+++ b/src/runtime/contrib/tensorrt/tensorrt_builder.h
@@ -48,8 +48,8 @@ using JSONGraphNodeEntry = 
tvm::runtime::json::JSONGraphNodeEntry;
  * perform inference.
  */
 struct TensorRTEngineAndContext {
-  nvinfer1::ICudaEngine* engine;
-  nvinfer1::IExecutionContext* context;
+  nvinfer1::ICudaEngine* engine = nullptr;
+  nvinfer1::IExecutionContext* context = nullptr;
   std::vector<std::string> inputs;
   std::vector<std::string> outputs;
 };
@@ -125,15 +125,15 @@ class TensorRTBuilder {
   std::unordered_map<int, std::vector<TensorRTOpInput>> node_output_map_;
 
   /*! \brief TensorRT builder. */
-  nvinfer1::IBuilder* builder_;
+  nvinfer1::IBuilder* builder_ = nullptr;
 
 #if TRT_VERSION_GE(6, 0, 1)
   /*! \brief TensorRT builder config. */
-  nvinfer1::IBuilderConfig* config_;
+  nvinfer1::IBuilderConfig* config_ = nullptr;
 #endif
 
   /*! \brief TensorRT network definition. */
-  nvinfer1::INetworkDefinition* network_;
+  nvinfer1::INetworkDefinition* network_ = nullptr;
 
   /*! \brief List of all weights held in memory. */
   std::vector<nvinfer1::Weights> trt_weights_;
diff --git a/src/runtime/contrib/tensorrt/tensorrt_ops.cc 
b/src/runtime/contrib/tensorrt/tensorrt_ops.cc
index 3971081bf8..cd46967e53 100644
--- a/src/runtime/contrib/tensorrt/tensorrt_ops.cc
+++ b/src/runtime/contrib/tensorrt/tensorrt_ops.cc
@@ -67,7 +67,7 @@ nvinfer1::ITensor* 
TensorRTOpConverter::Transpose(TensorRTOpConverterParams* par
     // Batch dimension cannot be modified.
     ICHECK_EQ(input->getDimensions().nbDims, order.size() - 1);
     ICHECK_EQ(order[0], 0);
-    for (size_t i = 0; i < order.size(); ++i) {
+    for (size_t i = 0; i + 1 < order.size(); ++i) {
       perm.order[i] = order[i + 1] - 1;
     }
   } else {
@@ -880,7 +880,7 @@ class ConcatOpConverter : public TensorRTOpConverter {
     const int input_rank = params->inputs[0].tensor->getDimensions().nbDims;
     std::vector<nvinfer1::ITensor*> input_tensors;
     for (auto input : params->inputs) {
-      ICHECK(input.type == kTensor);
+      ICHECK_EQ(input.type, kTensor);
       ICHECK_EQ(input_rank, input.tensor->getDimensions().nbDims);
       input_tensors.push_back(input.tensor);
     }
diff --git a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc 
b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc
index 18ffdbbbba..b51684b95e 100644
--- a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc
+++ b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc
@@ -138,13 +138,21 @@ class TensorRTRuntime : public JSONRuntimeBase {
   /*! \brief Destroy engines and contexts. */
   void DestroyEngines() {
     for (auto& it : trt_engine_cache_) {
+      VLOG(1) << "Destroying TensorRT context for function '" << 
it.first.first << "' (batch size "
+              << it.first.second << ")";
       it.second.context->destroy();
+      VLOG(1) << "Destroying TensorRT engine for function '" << it.first.first 
<< "' (batch size "
+              << it.first.second << ")";
       it.second.engine->destroy();
     }
     trt_engine_cache_.clear();
   }
 
-  ~TensorRTRuntime() { DestroyEngines(); }
+  ~TensorRTRuntime() override {
+    VLOG(1) << "Destroying TensorRT runtime";
+    DestroyEngines();
+    VLOG(1) << "Destroyed TensorRT runtime";
+  }
 
   /*! \brief Run inference using built engine. */
   void Run() override {
@@ -467,7 +475,7 @@ class TensorRTRuntime : public JSONRuntimeBase {
   /*! \brief TensorRT logger. */
   TensorRTLogger logger_;
 
-#else
+#else   // TVM_GRAPH_EXECUTOR_TENSORRT
   void Run() override {
     LOG(FATAL) << "TensorRT runtime is not enabled. "
                << "Please build with USE_TENSORRT_RUNTIME.";
@@ -481,7 +489,7 @@ class TensorRTRuntime : public JSONRuntimeBase {
   bool GetCachedEnginesFromDisk() { return false; }
 
   void CacheEngineToDisk() {}
-#endif
+#endif  // TVM_GRAPH_EXECUTOR_TENSORRT
 
   bool use_implicit_batch_;
 
diff --git a/src/target/metadata_module.cc b/src/target/metadata_module.cc
index ec301d1081..e5ca82d5c0 100644
--- a/src/target/metadata_module.cc
+++ b/src/target/metadata_module.cc
@@ -215,8 +215,6 @@ runtime::Module CreateMetadataModule(
       String symbol = pf_sym();
       Array<String> variables = pf_var();
       for (size_t i = 0; i < variables.size(); i++) {
-        VLOG(1) << "From module of type '" << mod->type_key() << "' found 
const var '"
-                << variables[i] << "' for symbol '" << symbol << "'";
         symbol_const_vars.push_back(variables[i].operator std::string());
       }
       ICHECK_EQ(const_vars_by_symbol.count(symbol), 0U) << "Found duplicated 
symbol: " << symbol;
diff --git a/tests/python/contrib/test_tensorrt.py 
b/tests/python/contrib/test_tensorrt.py
index cecb64785a..9e39821fd3 100644
--- a/tests/python/contrib/test_tensorrt.py
+++ b/tests/python/contrib/test_tensorrt.py
@@ -14,31 +14,37 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-import tvm.testing
+
 import numpy as np
 import pytest
 import itertools
+import logging
+from typing import Tuple
 
+try:
+    # See issue #9362.
+    import torch
+except:
+    pass
 
 import tvm
+import tvm.testing
 import tvm.relay.testing
 
 from tvm import relay
-from tvm.relay.op.contrib import tensorrt
-
 from tvm.relay import Any, GlobalVar
-
 from tvm.relay.expr_functor import ExprVisitor
-from typing import Tuple
 from tvm.contrib.download import download
 from tvm.relay.op.contrib import tensorrt
 
-
 SUPPORTED_DTYPES = ["float16", "float32"]
 
 has_tensorrt_codegen = pytest.mark.skipif(
-    not tvm.get_global_func("relay.ext.tensorrt", True), reason="TensorRT 
codegen not available"
+    not tensorrt.is_tensorrt_compiler_enabled(), reason="TensorRT codegen not 
available"
 )
+
+# CAUTION: Currently always false in CI since it adds tens of minutes to test 
time and depends
+# on a TensorRT installation. See https://github.com/apache/tvm/issues/11765
 has_tensorrt_runtime = pytest.mark.skipif(
     not tensorrt.is_tensorrt_runtime_enabled(), reason="TensorRT runtime not 
available"
 )
@@ -72,7 +78,7 @@ def assert_result_dict_holds(result_dict, dtype="float16"):
                 tvm.testing.assert_allclose(r1, r2, rtol=1e-3, atol=5e-3)
 
 
-def set_func_attr(func, compile_name, symbol_name):
+def set_outer_func_attr(func, compile_name, symbol_name):
     func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
     func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
     func = func.with_attr("Compiler", compile_name)
@@ -80,6 +86,12 @@ def set_func_attr(func, compile_name, symbol_name):
     return func
 
 
+def set_inner_func_attr(func, pattern_name, composite_name):
+    func = func.with_attr("PartitionedFromPattern", pattern_name)
+    func = func.with_attr("Composite", composite_name)
+    return func
+
+
 def run_and_verify_func(config, target="cuda", run_module=True, 
data_type="float32"):
     """Test a Relay func by compiling, running, and comparing TVM and TRT 
outputs.
 
@@ -110,34 +122,31 @@ def run_and_verify_func(config, target="cuda", 
run_module=True, data_type="float
 
     result_dict = dict()
     for mode in ["vm", "graph"]:
-        for mode in ["graph"]:
-            for use_trt in [True, False]:
-                mod = tvm.IRModule()
-                mod["main"] = f
-                result_key = mode + ("_trt" if use_trt else "")
-                if use_trt:
-                    mod = relay.transform.InferType()(mod)
-                    mod, config = tensorrt.partition_for_tensorrt(
-                        mod, params, use_fp16=data_type == "float16"
-                    )
-                    with tvm.transform.PassContext(
-                        opt_level=3, config={"relay.ext.tensorrt.options": 
config}
-                    ):
-                        func = relay.create_executor(
-                            mode, mod=mod, device=dev, target=target
-                        ).evaluate()
-                else:
-                    mod = relay.transform.InferType()(mod)
-                    with tvm.transform.PassContext(opt_level=3):
-                        func = relay.create_executor(
-                            mode, mod=mod, device=dev, target=target
-                        ).evaluate()
+        for use_trt in [True, False]:
+            mod = tvm.IRModule()
+            mod["main"] = f
+            result_key = mode + ("_trt" if use_trt else "")
+            if use_trt:
+                use_fp16 = data_type == "float16"
+                trt_target = tvm.target.Target(f"tensorrt 
-use_fp16={use_fp16}")
+                mod = relay.transform.InferType()(mod)
+                mod = tensorrt.partition_for_tensorrt(mod, params=params, 
target=trt_target)
+                with tvm.transform.PassContext(opt_level=3):
+                    func = relay.create_executor(
+                        mode, mod=mod, device=dev, target=[target, trt_target]
+                    ).evaluate()
+            else:
+                mod = relay.transform.InferType()(mod)
+                with tvm.transform.PassContext(opt_level=3):
+                    func = relay.create_executor(
+                        mode, mod=mod, device=dev, target=target
+                    ).evaluate()
 
-                if run_module:
-                    result_dict[result_key] = func(**input_dict, **params)
+            if run_module:
+                result_dict[result_key] = func(**input_dict, **params)
 
-                if run_module:
-                    assert_result_dict_holds(result_dict, data_type)
+            if run_module:
+                assert_result_dict_holds(result_dict, data_type)
 
 
 def test_tensorrt_simple(run_module):
@@ -163,10 +172,8 @@ def test_tensorrt_simple(run_module):
                 result_key = mode + ("_trt" if use_trt else "")
                 if use_trt:
                     mod = relay.transform.InferType()(mod)
-                    mod, config = tensorrt.partition_for_tensorrt(mod)
-                    with tvm.transform.PassContext(
-                        opt_level=3, config={"relay.ext.tensorrt.options": 
config}
-                    ):
+                    mod = tensorrt.partition_for_tensorrt(mod)
+                    with tvm.transform.PassContext(opt_level=3):
                         func = relay.create_executor(
                             mode, mod=mod, device=tvm.cuda(0), target="cuda"
                         ).evaluate()
@@ -212,9 +219,9 @@ def test_tensorrt_not_compatible(run_module):
     f = relay.Function([x], out)
     mod = tvm.IRModule()
     mod["main"] = f
-    mod, config = tensorrt.partition_for_tensorrt(mod)
+    mod = tensorrt.partition_for_tensorrt(mod)
     for mode in ["graph", "vm"]:
-        with tvm.transform.PassContext(opt_level=3, 
config={"relay.ext.tensorrt.options": config}):
+        with tvm.transform.PassContext(opt_level=3):
             func = relay.create_executor(
                 mode, mod=mod, device=tvm.cuda(0), target="cuda"
             ).evaluate()
@@ -622,26 +629,18 @@ class AreOpsOnGraph(ExprVisitor):
 
 
 def are_ops_on_trt(mod, op_list):
+    op_on_trt = False
+    op_on_tvm = False
     for subgraph in mod.get_global_vars():
         name = subgraph.name_hint
-        op_on_trt = False
-        op_on_tvm = True
-        if name == "main":
-            op_on_tvm = AreOpsOnGraph(op_list).are_ops_on_graph(mod[name].body)
-        elif mod[name].attrs and mod[name].attrs["Compiler"] == "tensorrt":
-            op_on_trt = AreOpsOnGraph(op_list).are_ops_on_graph(mod[name].body)
+        if mod[name].attrs and mod[name].attrs["Compiler"] == "tensorrt":
+            op_on_trt |= AreOpsOnGraph(op_list).are_ops_on_graph(mod[name].body)
         else:
-            op_on_tvm &= AreOpsOnGraph(op_list).are_ops_on_graph(mod[name].body)
-
-        if not op_on_trt or op_on_tvm:
-            return False
+            op_on_tvm |= AreOpsOnGraph(op_list).are_ops_on_graph(mod[name].body)
 
-    return True
+    return op_on_trt and not op_on_tvm
 
 
-@pytest.mark.xfail(
-    reason=("Currently failing test.  See tracking issue https://github.com/apache/tvm/issues/8901")
-)
 def test_dynamic_reshape(run_module):
     def test_run(x_data_list, x_shape, new_shape, should_offload_to_trt):
         result_arr = [{} for _ in range(len(x_data_list))]
@@ -652,9 +651,9 @@ def test_dynamic_reshape(run_module):
             mod = tvm.IRModule()
             mod["main"] = f
             if use_trt:
-                mod, _ = tensorrt.partition_for_tensorrt(
-                    mod, params={}, remove_no_mac_subgraphs=False
-                )
+                logging.info("Before partitioning:\n%s", mod)
+                mod = tensorrt.partition_for_tensorrt(mod)
+                logging.info("After partitioning:\n%s", mod)
                 assert are_ops_on_trt(mod, op_list=["reshape"]) == should_offload_to_trt
             if run_module:
                 with relay.build_config(opt_level=3):
@@ -1051,6 +1050,7 @@ def test_multiple_outputs(run_module):
         run_and_verify_func(get_graph(d_type=type), run_module=run_module, data_type=type)
 
 
+@pytest.mark.xfail(reason=("Fails assert_allclose. See https://github.com/apache/tvm/issues/11765"))
 def test_conv3d(run_module):
     def get_graph(
         x_shape=(1, 24, 8, 8, 8),
@@ -1143,11 +1143,7 @@ def test_conv3d_transpose(run_module):
     )
 
 
-@pytest.mark.xfail(
-    reason=("Currently failing test.  See tracking issue https://github.com/apache/tvm/issues/8901")
-)
 @has_tensorrt_codegen
-@tvm.testing.requires_cuda
 def test_dynamic_offload():
     """
     This test checks for proper dynamic offloading of relay graphs. An addition between
@@ -1161,24 +1157,29 @@ def test_dynamic_offload():
 
     x = relay.var("x", shape=(data_shape[0], data_shape[1], Any(), Any()), dtype="float32")
     y = relay.var("y", shape=(data_shape), dtype="float32")
-    kernel = relay.var("kernel", shape=(k_shape), dtype="float32")
+    kernel = relay.const(np.random.rand(*k_shape).astype("float32"))
 
     def get_expected():
         # Create a nested TRT function that matches the expected output
         mod = tvm.IRModule()
-        var1 = relay.var("tensorrt_0_i0", shape=(data_shape), dtype="float32")
-        kernel_trt = relay.var("tensorrt_0_i1", shape=(k_shape), dtype="float32")
-        out1 = relay.nn.conv2d(var1, kernel_trt, channels=k_shape[0], kernel_size=k_shape[2:4])
-        f1 = GlobalVar("tvmgen_default_tensorrt_0")
-        func = relay.Function([var1, kernel_trt], out1)
-        func = set_func_attr(func, "tensorrt", "tvmgen_default_tensorrt_0")
-        mod[f1] = func
+        outer_var = relay.var("tensorrt_0_i0", shape=(data_shape), dtype="float32")
+        inner_var = relay.var("FunctionVar_0_0", shape=(data_shape), dtype="float32")
+        inner_body = relay.nn.conv2d(
+            inner_var, kernel, channels=k_shape[0], kernel_size=k_shape[2:4]
+        )
+        inner_func = relay.Function([inner_var], inner_body)
+        inner_func = set_inner_func_attr(inner_func, "nn.conv2d_", "tensorrt.nn.conv2d")
+        outer_body = inner_func(outer_var)
+        outer_func = relay.Function([outer_var], outer_body)
+        outer_func = set_outer_func_attr(outer_func, "tensorrt", "tvmgen_default_tensorrt_main_0")
+        gv = GlobalVar("tvmgen_default_tensorrt_main_0")
+        mod[gv] = outer_func
         mod = relay.transform.InferType()(mod)
 
         # Create the main function
         out1 = relay.nn.conv2d(x, kernel, channels=k_shape[0], kernel_size=k_shape[2:4])
-        out = relay.add(out1, f1(y, kernel))
-        f = relay.Function([x, y, kernel], out)
+        out = relay.add(out1, gv(y))
+        f = relay.Function([x, y], out)
         mod["main"] = f
         mod = relay.transform.InferType()(mod)
         return mod
@@ -1187,13 +1188,13 @@ def test_dynamic_offload():
     out1 = relay.nn.conv2d(x, kernel, channels=k_shape[0], kernel_size=k_shape[2:4])
     out2 = relay.nn.conv2d(y, kernel, channels=k_shape[0], kernel_size=k_shape[2:4])
     out = relay.add(out1, out2)
-    f = relay.Function([x, y, kernel], out)
+    f = relay.Function([x, y], out)
 
     # Pass the function to TRT compilation
     mod = tvm.IRModule()
     mod["main"] = f
     mod = relay.transform.InferType()(mod)
-    mod_trt, config = tensorrt.partition_for_tensorrt(mod, params={})
+    mod_trt = tensorrt.partition_for_tensorrt(mod)
 
     # Get the expected relay graph and compare
     mod_exp = get_expected()
@@ -1212,7 +1213,7 @@ def test_tensorrt_dynamic_batch(run_module):
         mod = tvm.IRModule()
         mod["main"] = f
         if use_trt:
-            mod, _ = tensorrt.partition_for_tensorrt(mod)
+            mod = tensorrt.partition_for_tensorrt(mod)
 
         if run_module:
             with relay.build_config(opt_level=3):
@@ -1242,17 +1243,17 @@ def test_tensorrt_dynamic_batch_conv(run_module):
             f = relay.Function([x, kernel], out)
             mod = tvm.IRModule()
             mod["main"] = f
+            trt_target = tvm.target.Target(f"tensorrt -use_implicit_batch={use_implicit_batch}")
             if use_trt:
-                mod, config = tensorrt.partition_for_tensorrt(
-                    mod, params, use_implicit_batch=use_implicit_batch
-                )
+                mod = tensorrt.partition_for_tensorrt(mod, params=params, target=trt_target)
             if run_module:
                 for target in ["llvm", "cuda"]:
-                    with tvm.transform.PassContext(
-                        opt_level=3, config={"relay.ext.tensorrt.options": config}
-                    ):
+                    targets = [target]
+                    if use_trt:
+                        targets.append(trt_target)
+                    with tvm.transform.PassContext(opt_level=3):
                         func = relay.create_executor(
-                            "vm", mod=mod, device=tvm.device(target), target=target
+                            "vm", mod=mod, device=tvm.device(target), target=targets
                         ).evaluate()
                     for i, batch_size in enumerate(batches_to_test):
                         result_arr[i][target][use_trt] = func(x_data[:batch_size, ...], **params)
@@ -1281,9 +1282,11 @@ def test_maskrcnn_resnet50(run_module) -> None:
         input_name = "input0"
         shape_list = [(input_name, input_shape)]
         mod, params = relay.frontend.from_pytorch(traced_module, shape_list)
-        mod, config = tensorrt.partition_for_tensorrt(mod, params, remove_no_mac_subgraphs=True)
+        trt_target = tvm.target.Target("tensorrt -remove_no_mac_subgraphs=True")
+        mod = tensorrt.partition_for_tensorrt(mod, params=params, target=trt_target)
+        targets = [target, trt_target]
         with tvm.transform.PassContext(opt_level=3, disabled_pass=["FoldScaleAxis"]):
-            vm_trt_exec = relay.vm.compile(mod, target=target, params=params)
+            vm_trt_exec = relay.vm.compile(mod, target=targets, params=params)
 
         return vm_trt_exec
 
@@ -1381,7 +1384,7 @@ def test_empty_subgraph(run_module):
     var1 = relay.var("tensorrt_0_i0", shape=(x_shape), dtype="float32")
     f1 = GlobalVar("tensorrt_0")
     func = relay.Function([var1], var1)
-    func = set_func_attr(func, "tensorrt", "tvmgen_default_tensorrt_0")
+    func = set_outer_func_attr(func, "tensorrt", "tvmgen_default_tensorrt_0")
     mod[f1] = func
     mod = relay.transform.InferType()(mod)
 
@@ -1402,4 +1405,5 @@ def test_empty_subgraph(run_module):
 
 
 if __name__ == "__main__":
+    logging.basicConfig(level=logging.INFO)
     tvm.testing.main()
diff --git a/tests/python/contrib/test_tensorrt_int8_exp.py b/tests/python/contrib/test_tensorrt_int8_exp.py
index 84360e92d3..304d9a095e 100644
--- a/tests/python/contrib/test_tensorrt_int8_exp.py
+++ b/tests/python/contrib/test_tensorrt_int8_exp.py
@@ -18,8 +18,14 @@ import pytest
 import os
 import numpy as np
 
+try:
+    # See issue #9362.
+    import torch
+except:
+    pass
+
 import tvm
-import tvm.relay.testing
+import tvm.testing
 from tvm import relay
 from tvm.contrib.download import download_testdata
 from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt
@@ -31,9 +37,10 @@ def skip_codegen_test():
     if not tvm.runtime.enabled("cuda") or not tvm.cuda(0).exist:
         print("Skip because CUDA is not enabled.")
         return True
-    if not tvm.get_global_func("relay.ext.tensorrt", True):
-        print("Skip because TensorRT codegen is not available.")
+    if not tensorrt.is_tensorrt_compiler_enabled():
+        print("Skip because TensorRT compiler is not available.")
         return True
+    print("TensorRT compiler is available!")
     return False
 
 
@@ -44,6 +51,7 @@ def skip_runtime_test():
     if not tensorrt.is_tensorrt_runtime_enabled():
         print("Skip because TensorRT runtime is not available.")
         return True
+    print("TensorRT runtime is available!")
     return False
 
 
@@ -102,12 +110,11 @@ def test_trt_int8():
 
     # compile the model
     target = "cuda"
-    dev = tvm.cuda(1)
-    mod, config = partition_for_tensorrt(mod, params)
-    with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
+    dev = tvm.cuda()
+    mod = partition_for_tensorrt(mod, params)
+    with tvm.transform.PassContext(opt_level=3):
         lib = relay.build(mod, target=target, params=params)
 
-    dtype = "float32"
     gen_module = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
 
     num_cali_int8 = int(os.environ["TENSORRT_NUM_CALI_INT8"])
@@ -146,4 +153,4 @@ def test_trt_int8():
 
 
 if __name__ == "__main__":
-    pytest.main([__file__])
+    tvm.testing.main()

Reply via email to