This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d2a14a6880 [BYOC] Switch TensorRT BYOC integration to
IRModule-at-a-time using RelayToTIR hook (#11979)
d2a14a6880 is described below
commit d2a14a6880ee2e4520f3e9a55accb258e8725e65
Author: Mark Shields <[email protected]>
AuthorDate: Fri Jul 1 15:09:06 2022 -0700
[BYOC] Switch TensorRT BYOC integration to IRModule-at-a-time using
RelayToTIR hook (#11979)
* [BYOC] Switch TensorRT BYOC integration to IRModule-at-a-time using
RelayToTIR hook
This does for the TensorRT integration what #11631 did for the CUTLASS
integration.
- All compilation options are captured within the attributes of a Target of
kind "tensorrt" (instead of the "relay.ext.tensorrt.options" attribute in
PassContext). This means all BYOC configuration options needed by Collage
can be captured uniformly by a list of Targets. It also means RPC boundaries
(as used internally at OctoML) only need to maintain the fidelity of the
Target instance(s) rather than reaching into the PassContext.
- Compilation is switched from function-at-a-time (relying on the
TECompiler) to IRModule-at-a-time (using the RelayToTIR target-specific hook
mechanism). Though this is not strictly necessary for Collage, I want to
confirm the path is now clear to deprecate BYOC support in the TECompiler.
- Get all the TensorRT tests running again, except for a few I've disabled
with a cross-link to new issue #11765. CAUTION: The TensorRT runtime is not
supported in CI, so many of these tests are cosmetic.
- While trying to track down a 'free(): invalid pointer' error in
test_tensorrt_int8_exp.py, I made the TensorRT allocs/frees more robust, but
it turns out the test is also broken in main. No harm leaving these changes
in, though.
* - Lints
* - Woops, fix test
* - lints
* - Use default tensorrt target if none given in targets list
* - fix free error
* - accidentally introduced 'transforms' namespace
- can't use default Target("tensorrt") arg
* - D'oh! Include ended up #if protected
* - restore mark for test_dynamic_offload
- handle missing runtime in versioning
- turn test_maskrcnn_resnet50 back on now that we have the
import-torch-first workaround.
* - wibble
---
include/tvm/runtime/module.h | 2 +-
.../meta_schedule/testing/custom_builder_runner.py | 7 +-
python/tvm/relay/op/contrib/tensorrt.py | 191 +++++++--------
src/relay/backend/contrib/codegen_c/codegen.cc | 10 +-
src/relay/backend/contrib/cutlass/codegen.cc | 10 +-
src/relay/backend/contrib/tensorrt/codegen.cc | 265 ++++++++++++---------
.../contrib/tensorrt/{target.cc => codegen.h} | 23 +-
src/relay/backend/contrib/tensorrt/target.cc | 31 ++-
src/relay/transforms/compiler_function_utils.cc | 16 +-
src/relay/transforms/compiler_function_utils.h | 15 +-
src/runtime/const_loader_module.cc | 24 +-
src/runtime/contrib/json/json_runtime.h | 2 +
src/runtime/contrib/tensorrt/tensorrt_builder.cc | 27 ++-
src/runtime/contrib/tensorrt/tensorrt_builder.h | 10 +-
src/runtime/contrib/tensorrt/tensorrt_ops.cc | 4 +-
src/runtime/contrib/tensorrt/tensorrt_runtime.cc | 14 +-
src/target/metadata_module.cc | 2 -
tests/python/contrib/test_tensorrt.py | 172 ++++++-------
tests/python/contrib/test_tensorrt_int8_exp.py | 23 +-
19 files changed, 491 insertions(+), 357 deletions(-)
diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h
index 31d05571ee..9d139c9fef 100644
--- a/include/tvm/runtime/module.h
+++ b/include/tvm/runtime/module.h
@@ -113,7 +113,7 @@ class Module : public ObjectRef {
class TVM_DLL ModuleNode : public Object {
public:
/*! \brief virtual destructor */
- virtual ~ModuleNode() {}
+ virtual ~ModuleNode() = default;
/*!
* \return The per module type key.
* \note This key is used to for serializing custom modules.
diff --git a/python/tvm/meta_schedule/testing/custom_builder_runner.py
b/python/tvm/meta_schedule/testing/custom_builder_runner.py
index e203848c2c..1cfd4ab833 100644
--- a/python/tvm/meta_schedule/testing/custom_builder_runner.py
+++ b/python/tvm/meta_schedule/testing/custom_builder_runner.py
@@ -85,11 +85,8 @@ def build_relay_with_tensorrt(
from tvm.relay.op.contrib import tensorrt
from tvm.runtime import Module
- mod, config = tensorrt.partition_for_tensorrt(mod, params)
- with PassContext(
- opt_level=3,
- config={"relay.ext.tensorrt.options": config},
- ):
+ mod = tensorrt.partition_for_tensorrt(mod, params)
+ with PassContext(opt_level=3):
result = relay_build(mod, target=target, target_host=None,
params=params)
assert isinstance(result, Module)
return result
diff --git a/python/tvm/relay/op/contrib/tensorrt.py
b/python/tvm/relay/op/contrib/tensorrt.py
index a69e2d4105..4008b0eb3f 100644
--- a/python/tvm/relay/op/contrib/tensorrt.py
+++ b/python/tvm/relay/op/contrib/tensorrt.py
@@ -33,6 +33,10 @@ from tvm.relay.op.contrib.register import
register_pattern_table
logger = logging.getLogger("TensorRT")
+def is_tensorrt_compiler_enabled() -> bool:
+ return tvm.get_global_func("relay.ext.tensorrt.is_runtime_enabled", True)
is not None
+
+
def is_tensorrt_runtime_enabled() -> bool:
"""Check if the TensorRT graph executor is present.
Returns
@@ -40,118 +44,105 @@ def is_tensorrt_runtime_enabled() -> bool:
ret: bool
True if present, False if not.
"""
- check_enabled =
tvm.get_global_func("relay.op.is_tensorrt_runtime_enabled", True)
+ check_enabled =
tvm.get_global_func("relay.ext.tensorrt.is_runtime_enabled", True)
if check_enabled:
return check_enabled()
return False
+def get_tensorrt_target() -> tvm.target.Target:
+ """Returns the current Target, which must be of kind "tensorrt"."""
+ target = tvm.target.Target.current()
+ if target is None or target.kind.name != "tensorrt":
+ # Create the default target.
+ return tvm.target.Target("tensorrt")
+ return target
+
+
def get_tensorrt_version() -> Tuple[int, int, int]:
- """Gets the version of TensorRT that TVM is built against or is targeting.
+ """Returns the version of TensorRT to assume during compilation.
+ In order of preference this is taken from:
+ - The current "tensorrt" target's "tensorrt_version" attribute string.
+ - The version linked to the TVM runtime.
+ - (6, 0, 1)
Returns
-------
ret: Tuple[int, int, int]
- TensorRT version as a tuple of major, minor, and patch number. If TVM
- is not built with TensorRT, the value set by set_tensorrt_version() is
returned instead.
+ TensorRT version as a tuple of (major, minor, patch).
"""
- pass_ctx = tvm.transform.PassContext.current()
- if "relay.ext.tensorrt.options" in pass_ctx.config:
- return
tuple(pass_ctx.config["relay.ext.tensorrt.options"].tensorrt_version) # type:
ignore
- return tuple(tvm.get_global_func("relay.op.get_tensorrt_version")()) #
type: ignore
+ # cf logic in tensorrt/codegen.cc::SaveGlobalAttributes
+ # First check for version in target.
+ target = get_tensorrt_target()
+ version = target.attrs["tensorrt_version"]
+ if len(version) == 3:
+ return int(version[0]), int(version[1]), int(version[2])
+ assert len(version) == 0
+
+ # Next, ask runtime for its version.
+ if is_tensorrt_runtime_enabled():
+ get_version = tvm.get_global_func("relay.ext.tensorrt.get_version")
+ version = get_version()
+ assert len(version) == 3
+ return int(version[0]), int(version[1]), int(version[2])
+
+ # Finally, use default.
+ logger.warning(
+ "TVM was not built against TensorRT and no version was provided in the
'tensorrt' target."
+ " Defaulting to 6.0.1."
+ )
+ return (6, 0, 1)
def get_tensorrt_use_implicit_batch_mode() -> bool:
- pass_ctx = tvm.transform.PassContext.current()
- if "relay.ext.tensorrt.options" in pass_ctx.config:
- return pass_ctx.config["relay.ext.tensorrt.options"].use_implicit_batch
- logger.warning(
- "PassContext has no relay.ext.tensorrt.options config, using default
value "
- "use_implicit_batch=True."
- )
- return True
+ """Returns the "use_implicit_batch" attribute of the current "tensorrt"
target."""
+ target = get_tensorrt_target()
+ return target.attrs["use_implicit_batch"]
def get_tensorrt_remove_no_mac_subgraphs() -> bool:
- pass_ctx = tvm.transform.PassContext.current()
- if "relay.ext.tensorrt.options" in pass_ctx.config:
- return
pass_ctx.config["relay.ext.tensorrt.options"].remove_no_mac_subgraphs
- logger.warning(
- "PassContext has no relay.ext.tensorrt.options config, using default
value "
- "remove_no_mac_subgraphs=False."
- )
- return False
+ """Returns the "remove_no_mac_subgraphs" attribute of the current
"tensorrt" target."""
+ target = get_tensorrt_target()
+ return target.attrs["remove_no_mac_subgraphs"]
+
+
+def get_tensorrt_use_fp16() -> bool:
+ """Returns the "use_fp16" attribute of the current "tensorrt" target."""
+ target = get_tensorrt_target()
+ return target.attrs["use_fp16"]
def partition_for_tensorrt(
mod: tvm.IRModule,
params: Optional[Dict[str, tvm.nd.NDArray]] = None,
- version: Optional[Tuple[int, int, int]] = None,
- use_implicit_batch: bool = True,
- remove_no_mac_subgraphs: bool = False,
- max_workspace_size: int = 1 << 30,
- use_fp16: bool = False,
- use_uint8: bool = False,
-) -> Tuple[tvm.IRModule, Dict[str, Any]]:
- """Partition the graph greedily offloading supported operators to TensorRT.
+ # CAUTION: Can't use default Target("tensorrt") here since the target kind
is only available
+ # if is_tensorrt_compiler_enabled() == True.
+ target: Optional[tvm.target.Target] = None,
+) -> tvm.IRModule:
+ """Partition all functions in mod to greedily offload supported operators
to TensorRT.
Parameters
----------
mod : tvm.IRModule
- The module to run passes on.
+ The module to partition.
+ target : tvm.target.Target
+ A target of kind "tensorrt" describing additional partitioning and
compilation options.
params : Optional[Dict[str, tvm.nd.NDArray]]
Constant input parameters.
- version : Optional[Tuple[int, int, int]]
- TensorRT version to target as tuple of (major, minor, patch). If TVM
is compiled with
- USE_TENSORRT_RUNTIME=ON, the linked TensorRT version will be used
instead.
- use_implicit_batch : bool
- Use TensorRT implicit batch mode (default true). Setting to false will
enable explicit batch
- mode which will widen supported operators to include those which
modify the batch dimension,
- but may reduce performance for some models.
- remove_no_mac_subgraphs : bool
- Removes subgraphs which have been partitioned for TensorRT if they do
not have any
- multiply-accumulate operations. The removed subgraphs will go through
TVM's standard
- compilation instead. Can improve performance.
- max_workspace_size : int
- How many bytes of workspace size to allow each subgraph to use for
TensorRT engine creation.
- See TensorRT documentation for more info.
- use_fp16: bool
- Allows, TRT to automatically convert FP32 inputs to FP16. Also, it is
required to be enabled
- if FP16 inputs tensors and weights are used.
- Note that TensorRT will still choose a higher-precision kernel if it
results in overall
- lower runtime, or if no low-precision implementation exists.
- use_uint8: bool
- Allows, TRT to automatically convert FP32 inputs to UINT8.
Returns
-------
- mod_and_config : Tuple[tvm.IRModule, Dict[str, Any]]
- A tuple of 1) annotated and partitioned module and 2)
"relay.ext.tensorrt.options"
- configuration which should be given to PassContext when building.
+ partitioned_mod : tvm.IRModule
+ The partitioned module.
"""
- config: Dict[str, Any] = {
- "use_implicit_batch": use_implicit_batch,
- "max_workspace_size": max_workspace_size,
- "remove_no_mac_subgraphs": remove_no_mac_subgraphs,
- "use_fp16": use_fp16,
- "use_uint8": use_uint8,
- }
- if version:
- assert isinstance(version, tuple) and len(version) == 3
- config["tensorrt_version"] = version
- else:
- linked_version =
tuple(tvm.get_global_func("relay.op.get_tensorrt_version")())
- if not linked_version:
- logger.warning(
- "TVM was not built against TensorRT and no version was
provided to "
- "partition_for_tensorrt. Defaulting to 6.0.1"
- )
- linked_version = (6, 0, 1)
- config["tensorrt_version"] = linked_version
-
+ assert is_tensorrt_compiler_enabled(), "Can only partition for TensorRT if
it is enabled"
if params:
mod["main"] = bind_params_by_name(mod["main"], params)
+ if target is None:
+ # Use a default target. The get_tensorrt_target() function will
similarly create an
+ # equivalent default target when compilation continues after
partitioning.
+ target = tvm.target.Target("tensorrt")
seq = tvm.transform.Sequential(
[
@@ -174,24 +165,27 @@ def partition_for_tensorrt(
transform.InferType(),
]
)
- with tvm.transform.PassContext(opt_level=3,
config={"relay.ext.tensorrt.options": config}):
+ with target:
mod = seq(mod)
- # TODO(mbs): Revisit
- # mod = prune_tensorrt_subgraphs(mod)
- return mod, config
+ mod = prune_tensorrt_subgraphs(mod)
+ return mod
def is_supported_trt_type(typ: Union[tvm.ir.TensorType, tvm.ir.TupleType],
op_name: str) -> bool:
"""Check whether a type is supported by TensorRT."""
- supported_dtypes = ["float32", "float16"]
+ supported_dtypes = ["float32"]
+ if get_tensorrt_use_fp16():
+ supported_dtypes.append("float16")
if isinstance(typ, tvm.ir.TensorType):
if typ.dtype not in supported_dtypes:
- logger.info(f"{op_name}: Only float32 and float16 tensor dtypes
are supported.")
+ logger.info(f"{op_name}: Only {supported_dtypes} tensor dtypes are
supported.")
return False
- # assumes dim 0 is for batch and can be dynamic
- # TODO(mbs): But does this depend use_implicit_batch flag?
- for dim_shape in typ.shape[1:]:
- if isinstance(dim_shape, tvm.tir.expr.Any):
+ dims = typ.shape
+ if get_tensorrt_use_implicit_batch_mode():
+ # The first dimension can be Any.
+ dims = dims[1:]
+ for dim in dims:
+ if isinstance(dim, tvm.tir.expr.Any):
logger.info(f"{op_name}: Only statically known tensor shapes
are supported.")
return False
elif isinstance(typ, tvm.ir.TupleType):
@@ -241,13 +235,19 @@ CheckFunc = Callable[[Any, List[relay.expr.Expr], str],
bool]
def make_predicate(checker: CheckFunc) -> Callable[[relay.expr.Expr], bool]:
+ """Returns the pattern predicate which performs the standard checks, then
invokes the
+ more primitive checker."""
+
def predicate(expr: relay.expr.Expr) -> bool:
op_name = get_op_name(expr)
attrs = get_attrs(expr)
args = get_args(expr)
if not all([is_supported_trt_type(arg.checked_type, op_name) for arg
in args]):
return False
- return checker(attrs, args, op_name)
+ if not checker(attrs, args, op_name):
+ return False
+ logger.info(f"{op_name}: Predicate passes")
+ return True
return predicate
@@ -535,11 +535,16 @@ def concatenate_checker(
if int(attrs.axis) == 0:
logger.info(f"{op_name}: can't modify batch dimension.")
return False
- if isinstance(args[0], relay.Tuple):
- for tuple_input in args[0].fields:
- if isinstance(tuple_input, Constant):
- logger.info(f"{op_name}: can't concatenate tensors with
constants.")
- return False
+
+ if not isinstance(args[0], relay.Tuple):
+ logger.info(f"{op_name}: concatenate must be applied to a literal
tuple")
+ return False
+
+ for tuple_input in args[0].fields:
+ if isinstance(tuple_input, Constant):
+ logger.info(f"{op_name}: can't concatenate tensors with
constants.")
+ return False
+
return True
diff --git a/src/relay/backend/contrib/codegen_c/codegen.cc
b/src/relay/backend/contrib/codegen_c/codegen.cc
index ee8724fe92..41f0a0a064 100644
--- a/src/relay/backend/contrib/codegen_c/codegen.cc
+++ b/src/relay/backend/contrib/codegen_c/codegen.cc
@@ -360,8 +360,8 @@ class CodegenCModule {
};
/*! \brief The actual translation pass. */
-transform::Pass CCompilerImpl() {
- auto pass_func = [=](IRModule mod, const transform::PassContext& pass_ctx) {
+tvm::transform::Pass CCompilerImpl() {
+ auto pass_func = [=](IRModule mod, const tvm::transform::PassContext&
pass_ctx) {
VLOG(1) << "CCompilerImpl input:" << std::endl << PrettyPrint(mod);
Target target = GetCCompilerTarget();
@@ -388,10 +388,10 @@ transform::Pass CCompilerImpl() {
return tvm::transform::CreateModulePass(pass_func, 0, "CCompilerImpl", {});
}
-transform::Pass CCompilerPass() {
+tvm::transform::Pass CCompilerPass() {
return transform::Sequential(
-
{transforms::OutlineCompilerFunctionsWithExistingGlobalSymbols("ccompiler"),
CCompilerImpl(),
- transforms::MarkCompilerFunctionsAsExtern("ccompiler")});
+
{transform::OutlineCompilerFunctionsWithExistingGlobalSymbols("ccompiler"),
CCompilerImpl(),
+ transform::MarkCompilerFunctionsAsExtern("ccompiler")});
}
} // namespace contrib
diff --git a/src/relay/backend/contrib/cutlass/codegen.cc
b/src/relay/backend/contrib/cutlass/codegen.cc
index de2934173b..2e76ab1cbb 100644
--- a/src/relay/backend/contrib/cutlass/codegen.cc
+++ b/src/relay/backend/contrib/cutlass/codegen.cc
@@ -902,8 +902,8 @@ class CutlassModuleCodegen {
* \brief A small shim to redirect to the
'relay.ext.cutlass.compile_for_cutlass' Python
* function which does the main CUTLASS training, c-code generation and
compilation steps.
*/
-transform::Pass CompileForCutlassImpl() {
- auto pass_func = [=](IRModule mod, const transform::PassContext& pass_ctx) {
+tvm::transform::Pass CompileForCutlassImpl() {
+ auto pass_func = [=](IRModule mod, const tvm::transform::PassContext&
pass_ctx) {
VLOG(1) << "CompileForCutlass input:" << std::endl << PrettyPrint(mod);
const auto* pf =
runtime::Registry::Get("relay.ext.cutlass.compile_for_cutlass");
ICHECK(pf != nullptr) << "Cannot find compile_for_cutlass function";
@@ -926,10 +926,10 @@ runtime::Module CreateCSourceModule(const IRModule& mod) {
TVM_REGISTER_GLOBAL("relay.ext.cutlass.create_c_source_module").set_body_typed(CreateCSourceModule);
-transform::Pass CompileForCutlass() {
+tvm::transform::Pass CompileForCutlass() {
return transform::Sequential(
-
{transforms::OutlineCompilerFunctionsWithExistingGlobalSymbols("cutlass"),
- CompileForCutlassImpl(),
transforms::MarkCompilerFunctionsAsExtern("cutlass")});
+ {transform::OutlineCompilerFunctionsWithExistingGlobalSymbols("cutlass"),
+ CompileForCutlassImpl(),
transform::MarkCompilerFunctionsAsExtern("cutlass")});
}
} // namespace cutlass
diff --git a/src/relay/backend/contrib/tensorrt/codegen.cc
b/src/relay/backend/contrib/tensorrt/codegen.cc
index e08cd240d4..1c4a8d7806 100644
--- a/src/relay/backend/contrib/tensorrt/codegen.cc
+++ b/src/relay/backend/contrib/tensorrt/codegen.cc
@@ -29,6 +29,7 @@
#include <string>
#include <vector>
+#include "../../../transforms/compiler_function_utils.h"
#include "../../utils.h"
#include "../codegen_json/codegen_json.h"
@@ -39,36 +40,49 @@
namespace tvm {
namespace relay {
namespace contrib {
+namespace tensorrt {
-/*! \brief Attributes to store the compiler options for TensorRT. */
-struct TensorRTCompilerConfigNode : public
tvm::AttrsNode<TensorRTCompilerConfigNode> {
- Array<Integer> tensorrt_version;
- bool use_implicit_batch;
- size_t max_workspace_size;
- bool remove_no_mac_subgraphs;
- bool use_fp16;
- bool use_uint8;
-
- TVM_DECLARE_ATTRS(TensorRTCompilerConfigNode,
"ext.attrs.TensorRTCompilerConfigNode") {
- TVM_ATTR_FIELD(tensorrt_version)
- .describe("TensorRT version as (major, minor, patch).")
- .set_default(Array<Integer>({6, 0, 1}));
- TVM_ATTR_FIELD(use_implicit_batch).set_default(true);
- TVM_ATTR_FIELD(max_workspace_size).set_default(size_t(1) << 30);
- TVM_ATTR_FIELD(remove_no_mac_subgraphs).set_default(false);
- TVM_ATTR_FIELD(use_fp16).set_default(false);
- TVM_ATTR_FIELD(use_uint8).set_default(false);
- }
-};
+/*!
+ * \brief Check whether TensorRT graph executor is enabled.
+ * \return True if enabled, False if not.
+ */
+inline constexpr bool IsRuntimeEnabled() {
+#if TVM_GRAPH_EXECUTOR_TENSORRT
+ return true;
+#else
+ return false;
+#endif // TVM_GRAPH_EXECUTOR_TENSORRT
+}
-class TensorRTCompilerConfig : public Attrs {
- public:
- TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TensorRTCompilerConfig, Attrs,
- TensorRTCompilerConfigNode);
-};
+TVM_REGISTER_GLOBAL("relay.ext.tensorrt.is_runtime_enabled").set_body_typed(IsRuntimeEnabled);
-TVM_REGISTER_NODE_TYPE(TensorRTCompilerConfigNode);
-TVM_REGISTER_PASS_CONFIG_OPTION("relay.ext.tensorrt.options",
TensorRTCompilerConfig);
+/*!
+ * \brief Get TensorRT version that TVM is built against.
+ * \return Array of three integers for major, minor, and patch, or empty array
if TensorRT graph
+ * runtime is not enabled.
+ */
+Array<Integer> GetVersion() {
+#if TVM_GRAPH_EXECUTOR_TENSORRT
+ return {Integer(NV_TENSORRT_MAJOR), Integer(NV_TENSORRT_MINOR),
Integer(NV_TENSORRT_PATCH)};
+#else
+ return {};
+#endif // TVM_GRAPH_EXECUTOR_TENSORRT
+}
+
+TVM_REGISTER_GLOBAL("relay.ext.tensorrt.get_version").set_body_typed(GetVersion);
+
+/*!
+ * \brief Returns the "tensorrt" Target instance to use for compilation.
+ */
+Target GetTensorRTTarget() {
+ Target target = Target::Current(/*allow_not_defined=*/true);
+ if (!target.defined() || target->kind->name != "tensorrt") {
+ // Since we allow partition_for_tensorrt to use the default "tensorrt"
target, we should
+ // similarly allow the custom pass to execute without a specific
"tensorrt" target in scope.
+ target = Target("tensorrt");
+ }
+ return target;
+}
using JSONGraphNode = tvm::runtime::json::JSONGraphNode;
using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry;
@@ -87,6 +101,7 @@ class CollectFromCompositeFunctionBody : public ExprVisitor {
explicit CollectFromCompositeFunctionBody(TensorRTJSONSerializer* serializer)
: serializer_(serializer), node_(std::make_shared<JSONGraphNode>()) {}
+ // We'll need to implement these out-of-band since they use the serializer.
void VisitExpr_(const ConstantNode* constant_node) final;
void VisitExpr_(const CallNode* call_node) final;
@@ -190,6 +205,7 @@ class CollectFromCompositeFunctionBody : public ExprVisitor
{
extractor.Extract(const_cast<Object*>(attr_obj));
}
+ /*! \brief The parent serializer for the overall TensorRT partition. */
TensorRTJSONSerializer* serializer_;
/*! \brief Accumulated translated arguments. */
std::vector<JSONGraphNodeEntry> args_;
@@ -207,9 +223,10 @@ class CollectFromCompositeFunctionBody : public
ExprVisitor {
*/
class TensorRTJSONSerializer : public JSONSerializer {
public:
- TensorRTJSONSerializer(const std::string& symbol, const Expr& expr)
- : JSONSerializer(symbol, expr) {}
+ TensorRTJSONSerializer(Target target, const std::string& symbol, const Expr&
expr)
+ : JSONSerializer(symbol, expr), target_(std::move(target)) {}
+ private:
using JSONSerializer::VisitExpr_;
std::vector<JSONGraphNodeEntry> VisitExpr_(const CallNode* call_node) final {
@@ -245,40 +262,62 @@ class TensorRTJSONSerializer : public JSONSerializer {
node->CaptureAttrs(*collector.node_);
// Capture global settings on the JSON node.
- SaveGlobalAttributes(node);
+ // TODO(mbs): Why on every call?
+ SaveGlobalAttributes(node.get());
VLOG(1) << name << " has " << node->GetInputs().size() << " inputs";
return AddNode(node, GetRef<Expr>(call_node));
}
- static void SaveGlobalAttributes(std::shared_ptr<JSONGraphNode> node) {
- auto ctx = transform::PassContext::Current();
- auto cfg =
ctx->GetConfig<TensorRTCompilerConfig>("relay.ext.tensorrt.options");
- if (!cfg.defined()) {
- cfg = AttrsWithDefaultValues<TensorRTCompilerConfig>();
+ static void SetAttr(JSONGraphNode* node, const std::string& key,
+ std::vector<std::string> values) {
+ node->SetAttr(key, std::vector<dmlc::any>({std::move(values)}));
+ }
+
+ /*! \brief Capture the compilation options as attributes on \p node. */
+ void SaveGlobalAttributes(JSONGraphNode* node) {
+ {
+ // cf logic in tensorrt.py::get_tensorrt_version.
+ // First check for version in target.
+ Array<Integer> target_attr =
target_->GetAttr<Array<Integer>>("tensorrt_version").value();
+ if (target_attr.empty()) {
+ // Next, ask runtime for its version.
+ target_attr = GetVersion();
+ }
+ if (target_attr.empty()) {
+ // Finally, use default.
+ target_attr = {6, 0, 1};
+ }
+ ICHECK_EQ(target_attr.size(), 3);
+ SetAttr(node, "tensorrt_version",
+ {std::to_string(target_attr[0]), std::to_string(target_attr[1]),
+ std::to_string(target_attr[2])});
+ }
+
+ {
+ Bool target_attr = target_->GetAttr<Bool>("use_implicit_batch").value();
+ SetAttr(node, "use_implicit_batch",
{std::to_string(target_attr->value)});
+ }
+
+ {
+ Integer target_attr =
target_->GetAttr<Integer>("max_workspace_size").value();
+ SetAttr(node, "max_workspace_size",
{std::to_string(target_attr->value)});
+ }
+
+ {
+ Bool target_attr = target_->GetAttr<Bool>("use_fp16").value();
+ SetAttr(node, "use_fp16", {std::to_string(target_attr->value)});
+ }
+
+ {
+ Bool target_attr = target_->GetAttr<Bool>("use_uint8").value();
+ SetAttr(node, "use_uint8", {std::to_string(target_attr->value)});
}
- ICHECK_EQ(cfg.value()->tensorrt_version.size(), 3);
- std::vector<std::string> tensorrt_version =
{std::to_string(cfg.value()->tensorrt_version[0]),
-
std::to_string(cfg.value()->tensorrt_version[1]),
-
std::to_string(cfg.value()->tensorrt_version[2])};
- std::vector<std::string> use_implicit_batch =
{std::to_string(cfg.value()->use_implicit_batch)};
- std::vector<std::string> max_workspace_size =
{std::to_string(cfg.value()->max_workspace_size)};
- std::vector<std::string> use_fp16 =
{std::to_string(cfg.value()->use_fp16)};
- std::vector<std::string> use_uint8 =
{std::to_string(cfg.value()->use_uint8)};
- std::vector<dmlc::any> tensorrt_version_attr, use_implicit_batch_attr,
max_workspace_size_attr,
- use_fp16_attr, use_uint8_attr;
- tensorrt_version_attr.emplace_back(tensorrt_version);
- use_implicit_batch_attr.emplace_back(use_implicit_batch);
- max_workspace_size_attr.emplace_back(max_workspace_size);
- use_fp16_attr.emplace_back(use_fp16);
- use_uint8_attr.emplace_back(use_uint8);
- node->SetAttr("tensorrt_version", tensorrt_version_attr);
- node->SetAttr("use_implicit_batch", use_implicit_batch_attr);
- node->SetAttr("max_workspace_size", max_workspace_size_attr);
- node->SetAttr("use_fp16", use_fp16_attr);
- node->SetAttr("use_uint8", use_uint8_attr);
}
+
+ /*! \brief The "tensorrt" Target guiding compilation. */
+ Target target_;
};
void CollectFromCompositeFunctionBody::VisitExpr_(const ConstantNode*
constant_node) {
@@ -304,64 +343,74 @@ void CollectFromCompositeFunctionBody::VisitExpr_(const
CallNode* call_node) {
}
/*!
- * \brief Create a runtime module for TensorRT.
- * \param ref The ext_func Relay expression/module to be executed using extern
ops.
- * \return A runtime module.
- */
-runtime::Module TensorRTCompiler(const ObjectRef& ref) {
- ICHECK(ref->IsInstance<FunctionNode>()) << "The input ref is expected to be
a Relay function.";
- Function func = Downcast<Function>(ref);
- std::string func_name = backend::GetExtSymbol(func);
-
- VLOG(1) << "TensorRT partition:" << std::endl << PrettyPrint(func);
- TensorRTJSONSerializer serializer(func_name, func);
- serializer.serialize();
- std::string graph_json = serializer.GetJSON();
- VLOG(1) << "TensorRT JSON:" << std::endl << graph_json;
-
- // Note that serializer.const_name_to_constant() is ignored. Instead the
TECompiler invokes
- // a callback which calls backend::UpdateConstants to capture the map before
the function
- // 'disappears' into lowered form, on the assumption the visit order and
thus constant
- // names match those generated by the JSONSerializer.
-
- const auto* pf = runtime::Registry::Get("runtime.tensorrt_runtime_create");
- ICHECK(pf != nullptr) << "Cannot find TensorRT runtime module create
function.";
- VLOG(1) << "Creating tensorrt runtime::Module for '" << func_name << "'";
- runtime::Module lib = (*pf)(func_name, graph_json, serializer.const_names());
- return lib;
-}
-
-TVM_REGISTER_GLOBAL("relay.ext.tensorrt").set_body_typed(TensorRTCompiler);
-
-/*!
- * \brief Check whether TensorRT graph executor is enabled.
- * \return True if enabled, False if not.
+ * \brief The main TensorRT compiler.
+ *
+ * TODO(mbs): Currently we create a \p TensorRTRuntimeModule for every
function with
+ * Compiler="tensorrt" (ie for each partition). Since the TensorRT engine is
only designed to
+ * handle a single entry point this is mostly sensible, however there are
probably opportunities
+ * for more sharing between functions. However, note this means each call to a
TensorRT-compiled
+ * function will require a linear scan of imported runtime modules to find the
matching
+ * TensorRTRuntimeModule implementing it.
*/
-inline constexpr bool IsTensorRTRuntimeEnabled() {
-#if TVM_GRAPH_EXECUTOR_TENSORRT
- return true;
-#else
- return false;
-#endif // TVM_GRAPH_EXECUTOR_TENSORRT
+tvm::transform::Pass CompileForTensorRTImpl() {
+ auto pass_func = [](IRModule mod, const tvm::transform::PassContext&
pass_ctx) {
+ VLOG(1) << "CompileForTensorRT input:" << std::endl << PrettyPrint(mod);
+ Target target = GetTensorRTTarget();
+
+ const auto* pf = runtime::Registry::Get("runtime.tensorrt_runtime_create");
+ ICHECK(pf != nullptr) << "Cannot find TensorRT runtime module create
function.";
+
+ // The accumulated external runtime modules.
+ Array<runtime::Module> external_mods =
+
mod->GetAttr<Array<runtime::Module>>(tvm::attr::kExternalMods).value_or({});
+ // The accumulated constant bindings.
+ Map<String, runtime::NDArray> const_name_to_constant =
+ mod->GetAttr<Map<String,
runtime::NDArray>>(tvm::attr::kConstNameToConstant).value_or({});
+
+ for (const auto& kv : mod->functions) {
+ if (const auto* function_node = kv.second.as<FunctionNode>()) {
+ if (function_node->HasNonzeroAttr(attr::kPrimitive)) {
+ Optional<String> opt_compiler =
function_node->GetAttr<String>(attr::kCompiler);
+ if (opt_compiler && opt_compiler.value() == "tensorrt") {
+ // Serialize the function to JSON.
+ TensorRTJSONSerializer serializer(target, kv.first->name_hint,
+ GetRef<Function>(function_node));
+ serializer.serialize();
+ std::string graph_json = serializer.GetJSON();
+ VLOG(1) << "TensorRT JSON for '" << kv.first->name_hint << "':" <<
std::endl
+ << graph_json;
+
+ // Remember all the constant bindings.
+ for (const auto& kv2 : serializer.const_name_to_constant()) {
+ ICHECK_EQ(const_name_to_constant.count(kv2.first), 0);
+ VLOG(1) << "binding constant '" << kv2.first << "' for function
'"
+ << kv.first->name_hint << "'";
+ const_name_to_constant.Set(kv2.first, kv2.second);
+ }
+
+ // Create the actual runtime module.
+ runtime::Module runtime_mod =
+ (*pf)(kv.first->name_hint, graph_json,
serializer.const_names());
+
+ // Remember the runtime module.
+ external_mods.push_back(runtime_mod);
+ }
+ }
+ }
+ }
+ return WithAttrs(mod, {{tvm::attr::kExternalMods, external_mods},
+ {tvm::attr::kConstNameToConstant,
const_name_to_constant}});
+ };
+ return tvm::transform::CreateModulePass(pass_func, 0, "CompileForTensorRT",
{});
}
-/*!
- * \brief Get TensorRT version that TVM is built against.
- * \return Array of three integers for major, minor, and patch, or empty array
if TensorRT graph
- * runtime is not enabled.
- */
-Array<Integer> GetTensorRTVersion() {
-#if TVM_GRAPH_EXECUTOR_TENSORRT
- return {Integer(NV_TENSORRT_MAJOR), Integer(NV_TENSORRT_MINOR),
Integer(NV_TENSORRT_PATCH)};
-#else
- return {};
-#endif // TVM_GRAPH_EXECUTOR_TENSORRT
+tvm::transform::Pass CompileForTensorRT() {
+ return transform::Sequential(
+
{transform::OutlineCompilerFunctionsWithExistingGlobalSymbols("tensorrt"),
+ CompileForTensorRTImpl(),
transform::MarkCompilerFunctionsAsExtern("tensorrt")});
}
-TVM_REGISTER_GLOBAL("relay.op.is_tensorrt_runtime_enabled")
- .set_body_typed(IsTensorRTRuntimeEnabled);
-TVM_REGISTER_GLOBAL("relay.op.get_tensorrt_version").set_body_typed(GetTensorRTVersion);
-
+} // namespace tensorrt
} // namespace contrib
} // namespace relay
} // namespace tvm
diff --git a/src/relay/backend/contrib/tensorrt/target.cc
b/src/relay/backend/contrib/tensorrt/codegen.h
similarity index 57%
copy from src/relay/backend/contrib/tensorrt/target.cc
copy to src/relay/backend/contrib/tensorrt/codegen.h
index 85d127ab71..813a866375 100644
--- a/src/relay/backend/contrib/tensorrt/target.cc
+++ b/src/relay/backend/contrib/tensorrt/codegen.h
@@ -18,25 +18,30 @@
*/
/*!
- * \file src/relay/backend/contrib/tensorrt/target.cc
- * \brief Registers the "tensorrt" external codegen TargetKind.
+ * \file src/relay/backend/contrib/tensorrt/codegen.h
+ * \brief The 'custom' compilation pass for TensorRT (invoked by the
RelayToTIRTargetHook pass).
*/
-#include <tvm/target/target.h>
+#ifndef TVM_RELAY_BACKEND_CONTRIB_TENSORRT_CODEGEN_H_
+#define TVM_RELAY_BACKEND_CONTRIB_TENSORRT_CODEGEN_H_
+
+#include <tvm/ir/transform.h>
namespace tvm {
namespace relay {
namespace contrib {
+namespace tensorrt {
/*!
- * \brief This external codegen target can offload compilation to the TensorRT
compiler.
- * - Patterns: python/tvm/relay/op/contrib/tensorrt.py
- * - Custom compiler: src/relay/backend/contrib/tensorrt/codegen.cc
- * - Runtime: src/runtime/contrib/tensorrt/ *.cc
+ * \brief Returns the pass which replaces all calls to "Primitive" functions
with a "Compiler"
+ * attribute of "tensorrt" with calls to an extern which is implemented by a
\p TensorRTRuntime
+ * runtime module added to the IRModule's "external_mods" attribute.
*/
-TVM_REGISTER_TARGET_KIND("tensorrt", kDLCUDA)
- .set_attr<Bool>(tvm::attr::kIsExternalCodegen, Bool(true));
+transform::Pass CompileForTensorRT();
+} // namespace tensorrt
} // namespace contrib
} // namespace relay
} // namespace tvm
+
+#endif // TVM_RELAY_BACKEND_CONTRIB_TENSORRT_CODEGEN_H_
diff --git a/src/relay/backend/contrib/tensorrt/target.cc
b/src/relay/backend/contrib/tensorrt/target.cc
index 85d127ab71..2e4581d30a 100644
--- a/src/relay/backend/contrib/tensorrt/target.cc
+++ b/src/relay/backend/contrib/tensorrt/target.cc
@@ -24,19 +24,46 @@
#include <tvm/target/target.h>
+#include "./codegen.h"
+
namespace tvm {
namespace relay {
namespace contrib {
+namespace tensorrt {
/*!
* \brief This external codegen target can offload compilation to the TensorRT
compiler.
* - Patterns: python/tvm/relay/op/contrib/tensorrt.py
* - Custom compiler: src/relay/backend/contrib/tensorrt/codegen.cc
- * - Runtime: src/runtime/contrib/tensorrt/ *.cc
+ * - Runtime: src/runtime/contrib/tensorrt/...
*/
TVM_REGISTER_TARGET_KIND("tensorrt", kDLCUDA)
- .set_attr<Bool>(tvm::attr::kIsExternalCodegen, Bool(true));
+ .set_attr<Bool>(tvm::attr::kIsExternalCodegen, Bool(true))
+ .set_attr<FTVMRelayToTIR>("RelayToTIR", CompileForTensorRT())
+ // A array of three integers given the major, minor, and patch numbers for
the supported
+ // TensorRT compiler version. If empty will be auto-detected from linked
library. Default empty.
+ .add_attr_option<Array<Integer>>("tensorrt_version", Array<Integer>())
+ // If true, the first tensor dimension for most operators is allowed to be
Any and
+ // TensorRT will assume it represents a batch dimension only known at
inference time.
+ // Fewer Relay operators are supported in implicit batch mode. Default
true.
+ .add_attr_option<Bool>("use_implicit_batch", Bool(true))
+ // If true, excludes sub-graphs which do not have multiply-accumulate
operations, even though
+ // TensorRT supports them. ad. This is a simple heuristic to optimize the
partitioning between
+ // TensorRT and TVM. Not required if using Collage for partitioning.
Defalut false.
+ .add_attr_option<Bool>("remove_no_mac_subgraphs", Bool(false))
+ // How many bytes of workspace size to allow each subgraph to use for
TensorRT engine creation.
+ // Default 1G.
+ .add_attr_option<Integer>("max_workspace_size", Integer(1 << 30))
+ // If true, allows TensorRT to automatically convert float32 operations to
float16. Must also be
+ // enabled if any float16 operations are in the model. Note that TensorRT
may still choose a
+ // higher-precision kernel if it results in overall lower runtime, or if
no low-precision
+ // implementation exists. Default false.
+ .add_attr_option<Bool>("use_fp16", Bool(false))
+ // If true, allows TensorRT to automatically convert float32 operations to
uint8
+ // (aka quantized). Default false.
+ .add_attr_option<Bool>("use_uint8", Bool(false));
+} // namespace tensorrt
} // namespace contrib
} // namespace relay
} // namespace tvm
diff --git a/src/relay/transforms/compiler_function_utils.cc
b/src/relay/transforms/compiler_function_utils.cc
index 0df9f5ee29..1dafcd10a3 100644
--- a/src/relay/transforms/compiler_function_utils.cc
+++ b/src/relay/transforms/compiler_function_utils.cc
@@ -24,14 +24,13 @@
#include "./compiler_function_utils.h"
-#include "../op/call/call.h"
#include "tvm/relay/analysis.h"
#include "tvm/relay/expr_functor.h"
#include "tvm/relay/transform.h"
namespace tvm {
namespace relay {
-namespace transforms {
+namespace transform {
namespace {
/*!
@@ -211,8 +210,8 @@ GlobalVar ExistingGlobalSymbolCache::GetGlobalSymbol(const
Function& function) {
return global_var;
}
-transform::Pass OutlineCompilerFunctions(std::shared_ptr<GlobalSymbolCache>
cache,
- std::string compiler_filter) {
+tvm::transform::Pass
OutlineCompilerFunctions(std::shared_ptr<GlobalSymbolCache> cache,
+ std::string compiler_filter) {
runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)>
pass_func =
[cache = std::move(cache), compiler_filter = std::move(compiler_filter)](
IRModule mod, transform::PassContext ctx) {
@@ -235,12 +234,13 @@ transform::Pass
OutlineCompilerFunctions(std::shared_ptr<GlobalSymbolCache> cach
}
// Any Java programmers in the house?
-transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols(std::string
compiler_filter) {
+tvm::transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols(
+ std::string compiler_filter) {
return
OutlineCompilerFunctions(std::make_shared<ExistingGlobalSymbolCache>(),
std::move(compiler_filter));
}
-transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter) {
+tvm::transform::Pass MarkCompilerFunctionsAsExtern(std::string
compiler_filter) {
runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)>
pass_func =
[compiler_filter = std::move(compiler_filter)](IRModule mod,
transform::PassContext ctx) {
VLOG(1) << "MarkCompilerFunctionsAsExtern input:" << std::endl <<
PrettyPrint(mod);
@@ -262,7 +262,7 @@ transform::Pass MarkCompilerFunctionsAsExtern(std::string
compiler_filter) {
return tvm::transform::CreateModulePass(pass_func, 0,
"MarkCompilerFunctionsAsExtern", {});
}
-transform::Pass InlineCompilerFunctionsBoundTo(Array<GlobalVar> global_vars) {
+tvm::transform::Pass InlineCompilerFunctionsBoundTo(Array<GlobalVar>
global_vars) {
runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)>
pass_func =
[global_vars = std::move(global_vars)](IRModule mod,
transform::PassContext ctx) {
VLOG(1) << "InlineCompilerFunctionsBoundTo with global_vars: " <<
PrettyPrint(global_vars);
@@ -295,6 +295,6 @@
TVM_REGISTER_GLOBAL("relay._transform.MarkCompilerFunctionsAsExtern")
TVM_REGISTER_GLOBAL("relay._transform.InlineCompilerFunctionsBoundTo")
.set_body_typed(InlineCompilerFunctionsBoundTo);
-} // namespace transforms
+} // namespace transform
} // namespace relay
} // namespace tvm
diff --git a/src/relay/transforms/compiler_function_utils.h
b/src/relay/transforms/compiler_function_utils.h
index aa98430318..f3499faec2 100644
--- a/src/relay/transforms/compiler_function_utils.h
+++ b/src/relay/transforms/compiler_function_utils.h
@@ -66,7 +66,7 @@
namespace tvm {
namespace relay {
-namespace transforms {
+namespace transform {
/*!
* \brief Abstract class representing a cache of unique global vars keyed by
functions. This can
@@ -105,8 +105,8 @@ class ExistingGlobalSymbolCache : public GlobalSymbolCache {
* If \p compiler_filter is non-empty only functions with that as their
attribute value are
* outlined.
*/
-transform::Pass OutlineCompilerFunctions(std::shared_ptr<GlobalSymbolCache>
cache,
- std::string compiler_filter = "");
+tvm::transform::Pass
OutlineCompilerFunctions(std::shared_ptr<GlobalSymbolCache> cache,
+ std::string compiler_filter =
"");
/*!
* \brief A pass to outline all let-bound and literal functions in direct call
positions which have
@@ -119,7 +119,8 @@ transform::Pass
OutlineCompilerFunctions(std::shared_ptr<GlobalSymbolCache> cach
* This pass may be useful for external codegen using the "RelayToTIR" custom
pass mechanism
* to prepare the IRModule before custom lowering.
*/
-transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols(std::string
compiler_filter = "");
+tvm::transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols(
+ std::string compiler_filter = "");
/*!
* \brief A pass to mark all global functions which have a "Compiler"
attribute matching
@@ -132,7 +133,7 @@ transform::Pass
OutlineCompilerFunctionsWithExistingGlobalSymbols(std::string co
* This pass may be useful for external codegen using the "RelayToTIR" custom
pass mechanism to
* cleanup the IRModule after custom lowering.
*/
-transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter =
"");
+tvm::transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter
= "");
/*!
* \brief A pass to inline all global "Compiler" functions which are bound to
a global var
@@ -142,9 +143,9 @@ transform::Pass MarkCompilerFunctionsAsExtern(std::string
compiler_filter = "");
* This pass may be useful for external codegen which needs to undo
partitioning based on
* properties of the entire partition.
*/
-transform::Pass InlineCompilerFunctionsBoundTo(Array<GlobalVar> global_vars);
+tvm::transform::Pass InlineCompilerFunctionsBoundTo(Array<GlobalVar>
global_vars);
-} // namespace transforms
+} // namespace transform
} // namespace relay
} // namespace tvm
diff --git a/src/runtime/const_loader_module.cc
b/src/runtime/const_loader_module.cc
index 2e91d26d5f..a8028e616c 100644
--- a/src/runtime/const_loader_module.cc
+++ b/src/runtime/const_loader_module.cc
@@ -51,15 +51,24 @@ class ConstLoaderModuleNode : public ModuleNode {
const std::unordered_map<std::string, NDArray>& const_var_ndarray,
const std::unordered_map<std::string, std::vector<std::string>>&
const_vars_by_symbol)
: const_var_ndarray_(const_var_ndarray),
const_vars_by_symbol_(const_vars_by_symbol) {
+ VLOG(1) << "Creating ConstLoaderModule";
// Only the related submodules are cached to reduce the number of runtime
// symbol lookup for initialization. Otherwise, symbols/primitives in the
// DSO module will also be cached but they never need to be initialized.
- for (const auto& it : const_vars_by_symbol_) {
- initialized_[it.first] = false;
+ for (const auto& kv : const_vars_by_symbol_) {
+ for (const auto& var : kv.second) {
+ VLOG(1) << "ConstLoaderModuleNode has constant '" << var << "' for
function '" << kv.first
+ << "'";
+ ICHECK_GT(const_var_ndarray_.count(var), 0)
+ << "ConstLoaderModuleNode is missing entry for constant '" << var
<< "' for function '"
+ << kv.first << "'";
+ }
+ initialized_[kv.first] = false;
}
}
PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>&
sptr_to_self) final {
+ VLOG(1) << "ConstLoaderModuleNode::GetFunction(" << name << ")";
// Initialize and memoize the module.
// Usually, we have some warmup runs. The module initialization should be
// done at this stage. Therefore, runtime overhead is not a concern.
@@ -88,11 +97,13 @@ class ConstLoaderModuleNode : public ModuleNode {
*/
Array<NDArray> GetRequiredConstants(const std::string& symbol) {
Array<NDArray> ret;
- ICHECK_GT(const_vars_by_symbol_.count(symbol), 0U) << "No symbol is
recorded for " << symbol;
+ ICHECK_GT(const_vars_by_symbol_.count(symbol), 0U)
+ << "No constants known for function '" << symbol << "'";
std::vector<std::string> vars = const_vars_by_symbol_[symbol];
- for (const auto& it : vars) {
- ICHECK_GT(const_var_ndarray_.count(it), 0U) << "Found not recorded
constant variable: " << it;
- ret.push_back(const_var_ndarray_[it]);
+ for (const auto& var : vars) {
+ ICHECK_GT(const_var_ndarray_.count(var), 0U)
+ << "No such constant variable '" << var << "' for function '" <<
symbol << "'";
+ ret.push_back(const_var_ndarray_[var]);
}
return ret;
}
@@ -229,5 +240,6 @@ TVM_REGISTER_GLOBAL("runtime.module.loadbinary_metadata")
.set_body_typed(ConstLoaderModuleNode::LoadFromBinary);
TVM_REGISTER_GLOBAL("runtime.module.loadbinary_const_loader")
.set_body_typed(ConstLoaderModuleNode::LoadFromBinary);
+
} // namespace runtime
} // namespace tvm
diff --git a/src/runtime/contrib/json/json_runtime.h
b/src/runtime/contrib/json/json_runtime.h
index 355390765d..3a02202b87 100644
--- a/src/runtime/contrib/json/json_runtime.h
+++ b/src/runtime/contrib/json/json_runtime.h
@@ -54,6 +54,8 @@ class JSONRuntimeBase : public ModuleNode {
LoadGraph(graph_json_);
}
+ ~JSONRuntimeBase() override = default;
+
const char* type_key() const override { return "json"; } // May be
overridden
/*! \brief Initialize a specific json runtime. */
diff --git a/src/runtime/contrib/tensorrt/tensorrt_builder.cc
b/src/runtime/contrib/tensorrt/tensorrt_builder.cc
index 5f923667d0..436a6db4c8 100644
--- a/src/runtime/contrib/tensorrt/tensorrt_builder.cc
+++ b/src/runtime/contrib/tensorrt/tensorrt_builder.cc
@@ -45,10 +45,11 @@ TensorRTBuilder::TensorRTBuilder(TensorRTLogger* logger,
max_workspace_size_(max_workspace_size),
use_implicit_batch_(use_implicit_batch),
use_fp16_(use_fp16),
- batch_size_(batch_size) {
+ use_int8_(false),
+ batch_size_(batch_size),
+ calibrator_(calibrator) {
// Create TRT builder and network.
builder_ = nvinfer1::createInferBuilder(*logger);
- use_int8_ = false;
#if TRT_VERSION_GE(6, 0, 1)
// Use INetworkV2.
@@ -58,8 +59,7 @@ TensorRTBuilder::TensorRTBuilder(TensorRTLogger* logger,
flags = 0U;
builder_->setMaxBatchSize(batch_size_);
}
- this->calibrator_ = calibrator;
- if (calibrator != nullptr) {
+ if (calibrator_ != nullptr) {
use_int8_ = true;
}
network_ = builder_->createNetworkV2(flags);
@@ -177,6 +177,7 @@ TensorRTEngineAndContext TensorRTBuilder::BuildEngine() {
if (use_int8_) {
config_->setFlag(nvinfer1::BuilderFlag::kINT8);
+ ICHECK(calibrator_);
config_->setInt8Calibrator(calibrator_);
LOG(INFO) << "config finishes setting up calibrator as INT8 mode ... ";
}
@@ -210,6 +211,9 @@ TensorRTEngineAndContext TensorRTBuilder::BuildEngine() {
nvinfer1::IExecutionContext* context = engine->createExecutionContext();
CleanUp();
+ ICHECK(engine);
+ ICHECK(context);
+
return {engine, context, network_input_names_, network_output_names_};
}
@@ -254,18 +258,33 @@ nvinfer1::ITensor*
TensorRTBuilder::GetInputAsTensor(const TensorRTOpInput& inpu
}
void TensorRTBuilder::CleanUp() {
+ VLOG(1) << "Destroying TensorRT network";
+ ICHECK(network_);
network_->destroy();
+ network_ = nullptr;
+
#if TRT_VERSION_GE(6, 0, 1)
+ VLOG(1) << "Destroying TensorRT config";
+ ICHECK(config_);
config_->destroy();
+ config_ = nullptr;
#endif
+
+ VLOG(1) << "Destroying TensorRT builder";
+ ICHECK(builder_);
builder_->destroy();
+ builder_ = nullptr;
+
+ VLOG(1) << "Destroying TensorRT weights";
for (auto weight : trt_weights_) {
+ ICHECK(weight.values);
if (weight.type == nvinfer1::DataType::kFLOAT || weight.type ==
nvinfer1::DataType::kHALF) {
delete[] static_cast<const float*>(weight.values);
} else {
delete[] static_cast<const uint16_t*>(weight.values);
}
}
+ trt_weights_.clear();
}
} // namespace contrib
diff --git a/src/runtime/contrib/tensorrt/tensorrt_builder.h
b/src/runtime/contrib/tensorrt/tensorrt_builder.h
index 13a118340e..9bccc1ea48 100644
--- a/src/runtime/contrib/tensorrt/tensorrt_builder.h
+++ b/src/runtime/contrib/tensorrt/tensorrt_builder.h
@@ -48,8 +48,8 @@ using JSONGraphNodeEntry =
tvm::runtime::json::JSONGraphNodeEntry;
* perform inference.
*/
struct TensorRTEngineAndContext {
- nvinfer1::ICudaEngine* engine;
- nvinfer1::IExecutionContext* context;
+ nvinfer1::ICudaEngine* engine = nullptr;
+ nvinfer1::IExecutionContext* context = nullptr;
std::vector<std::string> inputs;
std::vector<std::string> outputs;
};
@@ -125,15 +125,15 @@ class TensorRTBuilder {
std::unordered_map<int, std::vector<TensorRTOpInput>> node_output_map_;
/*! \brief TensorRT builder. */
- nvinfer1::IBuilder* builder_;
+ nvinfer1::IBuilder* builder_ = nullptr;
#if TRT_VERSION_GE(6, 0, 1)
/*! \brief TensorRT builder config. */
- nvinfer1::IBuilderConfig* config_;
+ nvinfer1::IBuilderConfig* config_ = nullptr;
#endif
/*! \brief TensorRT network definition. */
- nvinfer1::INetworkDefinition* network_;
+ nvinfer1::INetworkDefinition* network_ = nullptr;
/*! \brief List of all weights held in memory. */
std::vector<nvinfer1::Weights> trt_weights_;
diff --git a/src/runtime/contrib/tensorrt/tensorrt_ops.cc
b/src/runtime/contrib/tensorrt/tensorrt_ops.cc
index 3971081bf8..cd46967e53 100644
--- a/src/runtime/contrib/tensorrt/tensorrt_ops.cc
+++ b/src/runtime/contrib/tensorrt/tensorrt_ops.cc
@@ -67,7 +67,7 @@ nvinfer1::ITensor*
TensorRTOpConverter::Transpose(TensorRTOpConverterParams* par
// Batch dimension cannot be modified.
ICHECK_EQ(input->getDimensions().nbDims, order.size() - 1);
ICHECK_EQ(order[0], 0);
- for (size_t i = 0; i < order.size(); ++i) {
+ for (size_t i = 0; i + 1 < order.size(); ++i) {
perm.order[i] = order[i + 1] - 1;
}
} else {
@@ -880,7 +880,7 @@ class ConcatOpConverter : public TensorRTOpConverter {
const int input_rank = params->inputs[0].tensor->getDimensions().nbDims;
std::vector<nvinfer1::ITensor*> input_tensors;
for (auto input : params->inputs) {
- ICHECK(input.type == kTensor);
+ ICHECK_EQ(input.type, kTensor);
ICHECK_EQ(input_rank, input.tensor->getDimensions().nbDims);
input_tensors.push_back(input.tensor);
}
diff --git a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc
b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc
index 18ffdbbbba..b51684b95e 100644
--- a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc
+++ b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc
@@ -138,13 +138,21 @@ class TensorRTRuntime : public JSONRuntimeBase {
/*! \brief Destroy engines and contexts. */
void DestroyEngines() {
for (auto& it : trt_engine_cache_) {
+ VLOG(1) << "Destroying TensorRT context for function '" <<
it.first.first << "' (batch size "
+ << it.first.second << ")";
it.second.context->destroy();
+ VLOG(1) << "Destroying TensorRT engine for function '" << it.first.first
<< "' (batch size "
+ << it.first.second << ")";
it.second.engine->destroy();
}
trt_engine_cache_.clear();
}
- ~TensorRTRuntime() { DestroyEngines(); }
+ ~TensorRTRuntime() override {
+ VLOG(1) << "Destroying TensorRT runtime";
+ DestroyEngines();
+ VLOG(1) << "Destroyed TensorRT runtime";
+ }
/*! \brief Run inference using built engine. */
void Run() override {
@@ -467,7 +475,7 @@ class TensorRTRuntime : public JSONRuntimeBase {
/*! \brief TensorRT logger. */
TensorRTLogger logger_;
-#else
+#else // TVM_GRAPH_EXECUTOR_TENSORRT
void Run() override {
LOG(FATAL) << "TensorRT runtime is not enabled. "
<< "Please build with USE_TENSORRT_RUNTIME.";
@@ -481,7 +489,7 @@ class TensorRTRuntime : public JSONRuntimeBase {
bool GetCachedEnginesFromDisk() { return false; }
void CacheEngineToDisk() {}
-#endif
+#endif // TVM_GRAPH_EXECUTOR_TENSORRT
bool use_implicit_batch_;
diff --git a/src/target/metadata_module.cc b/src/target/metadata_module.cc
index ec301d1081..e5ca82d5c0 100644
--- a/src/target/metadata_module.cc
+++ b/src/target/metadata_module.cc
@@ -215,8 +215,6 @@ runtime::Module CreateMetadataModule(
String symbol = pf_sym();
Array<String> variables = pf_var();
for (size_t i = 0; i < variables.size(); i++) {
- VLOG(1) << "From module of type '" << mod->type_key() << "' found
const var '"
- << variables[i] << "' for symbol '" << symbol << "'";
symbol_const_vars.push_back(variables[i].operator std::string());
}
ICHECK_EQ(const_vars_by_symbol.count(symbol), 0U) << "Found duplicated
symbol: " << symbol;
diff --git a/tests/python/contrib/test_tensorrt.py
b/tests/python/contrib/test_tensorrt.py
index cecb64785a..9e39821fd3 100644
--- a/tests/python/contrib/test_tensorrt.py
+++ b/tests/python/contrib/test_tensorrt.py
@@ -14,31 +14,37 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-import tvm.testing
+
import numpy as np
import pytest
import itertools
+import logging
+from typing import Tuple
+try:
+ # See issue #9362.
+ import torch
+except:
+ pass
import tvm
+import tvm.testing
import tvm.relay.testing
from tvm import relay
-from tvm.relay.op.contrib import tensorrt
-
from tvm.relay import Any, GlobalVar
-
from tvm.relay.expr_functor import ExprVisitor
-from typing import Tuple
from tvm.contrib.download import download
from tvm.relay.op.contrib import tensorrt
-
SUPPORTED_DTYPES = ["float16", "float32"]
has_tensorrt_codegen = pytest.mark.skipif(
- not tvm.get_global_func("relay.ext.tensorrt", True), reason="TensorRT
codegen not available"
+ not tensorrt.is_tensorrt_compiler_enabled(), reason="TensorRT codegen not
available"
)
+
+# CAUTION: Currently always false in CI since adds tens of minutes to test
time and depends
+# on TensorRT installation. See https://github.com/apache/tvm/issues/11765
has_tensorrt_runtime = pytest.mark.skipif(
not tensorrt.is_tensorrt_runtime_enabled(), reason="TensorRT runtime not
available"
)
@@ -72,7 +78,7 @@ def assert_result_dict_holds(result_dict, dtype="float16"):
tvm.testing.assert_allclose(r1, r2, rtol=1e-3, atol=5e-3)
-def set_func_attr(func, compile_name, symbol_name):
+def set_outer_func_attr(func, compile_name, symbol_name):
func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func = func.with_attr("Compiler", compile_name)
@@ -80,6 +86,12 @@ def set_func_attr(func, compile_name, symbol_name):
return func
+def set_inner_func_attr(func, pattern_name, composite_name):
+ func = func.with_attr("PartitionedFromPattern", pattern_name)
+ func = func.with_attr("Composite", composite_name)
+ return func
+
+
def run_and_verify_func(config, target="cuda", run_module=True,
data_type="float32"):
"""Test a Relay func by compiling, running, and comparing TVM and TRT
outputs.
@@ -110,34 +122,31 @@ def run_and_verify_func(config, target="cuda",
run_module=True, data_type="float
result_dict = dict()
for mode in ["vm", "graph"]:
- for mode in ["graph"]:
- for use_trt in [True, False]:
- mod = tvm.IRModule()
- mod["main"] = f
- result_key = mode + ("_trt" if use_trt else "")
- if use_trt:
- mod = relay.transform.InferType()(mod)
- mod, config = tensorrt.partition_for_tensorrt(
- mod, params, use_fp16=data_type == "float16"
- )
- with tvm.transform.PassContext(
- opt_level=3, config={"relay.ext.tensorrt.options":
config}
- ):
- func = relay.create_executor(
- mode, mod=mod, device=dev, target=target
- ).evaluate()
- else:
- mod = relay.transform.InferType()(mod)
- with tvm.transform.PassContext(opt_level=3):
- func = relay.create_executor(
- mode, mod=mod, device=dev, target=target
- ).evaluate()
+ for use_trt in [True, False]:
+ mod = tvm.IRModule()
+ mod["main"] = f
+ result_key = mode + ("_trt" if use_trt else "")
+ if use_trt:
+ use_fp16 = data_type == "float16"
+ trt_target = tvm.target.Target(f"tensorrt
-use_fp16={use_fp16}")
+ mod = relay.transform.InferType()(mod)
+ mod = tensorrt.partition_for_tensorrt(mod, params=params,
target=trt_target)
+ with tvm.transform.PassContext(opt_level=3):
+ func = relay.create_executor(
+ mode, mod=mod, device=dev, target=[target, trt_target]
+ ).evaluate()
+ else:
+ mod = relay.transform.InferType()(mod)
+ with tvm.transform.PassContext(opt_level=3):
+ func = relay.create_executor(
+ mode, mod=mod, device=dev, target=target
+ ).evaluate()
- if run_module:
- result_dict[result_key] = func(**input_dict, **params)
+ if run_module:
+ result_dict[result_key] = func(**input_dict, **params)
- if run_module:
- assert_result_dict_holds(result_dict, data_type)
+ if run_module:
+ assert_result_dict_holds(result_dict, data_type)
def test_tensorrt_simple(run_module):
@@ -163,10 +172,8 @@ def test_tensorrt_simple(run_module):
result_key = mode + ("_trt" if use_trt else "")
if use_trt:
mod = relay.transform.InferType()(mod)
- mod, config = tensorrt.partition_for_tensorrt(mod)
- with tvm.transform.PassContext(
- opt_level=3, config={"relay.ext.tensorrt.options":
config}
- ):
+ mod = tensorrt.partition_for_tensorrt(mod)
+ with tvm.transform.PassContext(opt_level=3):
func = relay.create_executor(
mode, mod=mod, device=tvm.cuda(0), target="cuda"
).evaluate()
@@ -212,9 +219,9 @@ def test_tensorrt_not_compatible(run_module):
f = relay.Function([x], out)
mod = tvm.IRModule()
mod["main"] = f
- mod, config = tensorrt.partition_for_tensorrt(mod)
+ mod = tensorrt.partition_for_tensorrt(mod)
for mode in ["graph", "vm"]:
- with tvm.transform.PassContext(opt_level=3,
config={"relay.ext.tensorrt.options": config}):
+ with tvm.transform.PassContext(opt_level=3):
func = relay.create_executor(
mode, mod=mod, device=tvm.cuda(0), target="cuda"
).evaluate()
@@ -622,26 +629,18 @@ class AreOpsOnGraph(ExprVisitor):
def are_ops_on_trt(mod, op_list):
+ op_on_trt = False
+ op_on_tvm = False
for subgraph in mod.get_global_vars():
name = subgraph.name_hint
- op_on_trt = False
- op_on_tvm = True
- if name == "main":
- op_on_tvm = AreOpsOnGraph(op_list).are_ops_on_graph(mod[name].body)
- elif mod[name].attrs and mod[name].attrs["Compiler"] == "tensorrt":
- op_on_trt = AreOpsOnGraph(op_list).are_ops_on_graph(mod[name].body)
+ if mod[name].attrs and mod[name].attrs["Compiler"] == "tensorrt":
+ op_on_trt |=
AreOpsOnGraph(op_list).are_ops_on_graph(mod[name].body)
else:
- op_on_tvm &=
AreOpsOnGraph(op_list).are_ops_on_graph(mod[name].body)
-
- if not op_on_trt or op_on_tvm:
- return False
+ op_on_tvm |=
AreOpsOnGraph(op_list).are_ops_on_graph(mod[name].body)
- return True
+ return op_on_trt and not op_on_tvm
[email protected](
- reason=("Currently failing test. See tracking issue
https://github.com/apache/tvm/issues/8901")
-)
def test_dynamic_reshape(run_module):
def test_run(x_data_list, x_shape, new_shape, should_offload_to_trt):
result_arr = [{} for _ in range(len(x_data_list))]
@@ -652,9 +651,9 @@ def test_dynamic_reshape(run_module):
mod = tvm.IRModule()
mod["main"] = f
if use_trt:
- mod, _ = tensorrt.partition_for_tensorrt(
- mod, params={}, remove_no_mac_subgraphs=False
- )
+ logging.info("Before partitioning:\n%s", mod)
+ mod = tensorrt.partition_for_tensorrt(mod)
+ logging.info("After partitioning:\n%s", mod)
assert are_ops_on_trt(mod, op_list=["reshape"]) ==
should_offload_to_trt
if run_module:
with relay.build_config(opt_level=3):
@@ -1051,6 +1050,7 @@ def test_multiple_outputs(run_module):
run_and_verify_func(get_graph(d_type=type), run_module=run_module,
data_type=type)
+@pytest.mark.skip(reason=("Fails assert_allclose. See
https://github.com/apache/tvm/issues/11765"))
def test_conv3d(run_module):
def get_graph(
x_shape=(1, 24, 8, 8, 8),
@@ -1143,11 +1143,7 @@ def test_conv3d_transpose(run_module):
)
[email protected](
- reason=("Currently failing test. See tracking issue
https://github.com/apache/tvm/issues/8901")
-)
@has_tensorrt_codegen
@tvm.testing.requires_cuda
def test_dynamic_offload():
"""
This test checks for proper dynamic offloading of relay graphs. An
addition between
@@ -1161,24 +1157,29 @@ def test_dynamic_offload():
x = relay.var("x", shape=(data_shape[0], data_shape[1], Any(), Any()),
dtype="float32")
y = relay.var("y", shape=(data_shape), dtype="float32")
- kernel = relay.var("kernel", shape=(k_shape), dtype="float32")
+ kernel = relay.const(np.random.rand(*k_shape).astype("float32"))
def get_expected():
# Create a nested TRT function that matches the expected output
mod = tvm.IRModule()
- var1 = relay.var("tensorrt_0_i0", shape=(data_shape), dtype="float32")
- kernel_trt = relay.var("tensorrt_0_i1", shape=(k_shape),
dtype="float32")
- out1 = relay.nn.conv2d(var1, kernel_trt, channels=k_shape[0],
kernel_size=k_shape[2:4])
- f1 = GlobalVar("tvmgen_default_tensorrt_0")
- func = relay.Function([var1, kernel_trt], out1)
- func = set_func_attr(func, "tensorrt", "tvmgen_default_tensorrt_0")
- mod[f1] = func
+ outer_var = relay.var("tensorrt_0_i0", shape=(data_shape),
dtype="float32")
+ inner_var = relay.var("FunctionVar_0_0", shape=(data_shape),
dtype="float32")
+ inner_body = relay.nn.conv2d(
+ inner_var, kernel, channels=k_shape[0], kernel_size=k_shape[2:4]
+ )
+ inner_func = relay.Function([inner_var], inner_body)
+ inner_func = set_inner_func_attr(inner_func, "nn.conv2d_",
"tensorrt.nn.conv2d")
+ outer_body = inner_func(outer_var)
+ outer_func = relay.Function([outer_var], outer_body)
+ outer_func = set_outer_func_attr(outer_func, "tensorrt",
"tvmgen_default_tensorrt_main_0")
+ gv = GlobalVar("tvmgen_default_tensorrt_main_0")
+ mod[gv] = outer_func
mod = relay.transform.InferType()(mod)
# Create the main function
out1 = relay.nn.conv2d(x, kernel, channels=k_shape[0],
kernel_size=k_shape[2:4])
- out = relay.add(out1, f1(y, kernel))
- f = relay.Function([x, y, kernel], out)
+ out = relay.add(out1, gv(y))
+ f = relay.Function([x, y], out)
mod["main"] = f
mod = relay.transform.InferType()(mod)
return mod
@@ -1187,13 +1188,13 @@ def test_dynamic_offload():
out1 = relay.nn.conv2d(x, kernel, channels=k_shape[0],
kernel_size=k_shape[2:4])
out2 = relay.nn.conv2d(y, kernel, channels=k_shape[0],
kernel_size=k_shape[2:4])
out = relay.add(out1, out2)
- f = relay.Function([x, y, kernel], out)
+ f = relay.Function([x, y], out)
# Pass the function to TRT compilation
mod = tvm.IRModule()
mod["main"] = f
mod = relay.transform.InferType()(mod)
- mod_trt, config = tensorrt.partition_for_tensorrt(mod, params={})
+ mod_trt = tensorrt.partition_for_tensorrt(mod)
# Get the expected relay graph and compare
mod_exp = get_expected()
@@ -1212,7 +1213,7 @@ def test_tensorrt_dynamic_batch(run_module):
mod = tvm.IRModule()
mod["main"] = f
if use_trt:
- mod, _ = tensorrt.partition_for_tensorrt(mod)
+ mod = tensorrt.partition_for_tensorrt(mod)
if run_module:
with relay.build_config(opt_level=3):
@@ -1242,17 +1243,17 @@ def test_tensorrt_dynamic_batch_conv(run_module):
f = relay.Function([x, kernel], out)
mod = tvm.IRModule()
mod["main"] = f
+ trt_target = tvm.target.Target(f"tensorrt
-use_implicit_batch={use_implicit_batch}")
if use_trt:
- mod, config = tensorrt.partition_for_tensorrt(
- mod, params, use_implicit_batch=use_implicit_batch
- )
+ mod = tensorrt.partition_for_tensorrt(mod, params=params,
target=trt_target)
if run_module:
for target in ["llvm", "cuda"]:
- with tvm.transform.PassContext(
- opt_level=3, config={"relay.ext.tensorrt.options":
config}
- ):
+ targets = [target]
+ if use_trt:
+ targets.append(trt_target)
+ with tvm.transform.PassContext(opt_level=3):
func = relay.create_executor(
- "vm", mod=mod, device=tvm.device(target),
target=target
+ "vm", mod=mod, device=tvm.device(target),
target=targets
).evaluate()
for i, batch_size in enumerate(batches_to_test):
result_arr[i][target][use_trt] =
func(x_data[:batch_size, ...], **params)
@@ -1281,9 +1282,11 @@ def test_maskrcnn_resnet50(run_module) -> None:
input_name = "input0"
shape_list = [(input_name, input_shape)]
mod, params = relay.frontend.from_pytorch(traced_module, shape_list)
- mod, config = tensorrt.partition_for_tensorrt(mod, params,
remove_no_mac_subgraphs=True)
+ trt_target = tvm.target.Target("tensorrt
-remove_no_mac_subgraphs=True")
+ mod = tensorrt.partition_for_tensorrt(mod, params=params,
target=trt_target)
+ targets = [target, trt_target]
with tvm.transform.PassContext(opt_level=3,
disabled_pass=["FoldScaleAxis"]):
- vm_trt_exec = relay.vm.compile(mod, target=target, params=params)
+ vm_trt_exec = relay.vm.compile(mod, target=targets, params=params)
return vm_trt_exec
@@ -1381,7 +1384,7 @@ def test_empty_subgraph(run_module):
var1 = relay.var("tensorrt_0_i0", shape=(x_shape), dtype="float32")
f1 = GlobalVar("tensorrt_0")
func = relay.Function([var1], var1)
- func = set_func_attr(func, "tensorrt", "tvmgen_default_tensorrt_0")
+ func = set_outer_func_attr(func, "tensorrt", "tvmgen_default_tensorrt_0")
mod[f1] = func
mod = relay.transform.InferType()(mod)
@@ -1402,4 +1405,5 @@ def test_empty_subgraph(run_module):
if __name__ == "__main__":
+ logging.basicConfig(level=logging.INFO)
tvm.testing.main()
diff --git a/tests/python/contrib/test_tensorrt_int8_exp.py
b/tests/python/contrib/test_tensorrt_int8_exp.py
index 84360e92d3..304d9a095e 100644
--- a/tests/python/contrib/test_tensorrt_int8_exp.py
+++ b/tests/python/contrib/test_tensorrt_int8_exp.py
@@ -18,8 +18,14 @@ import pytest
import os
import numpy as np
+try:
+ # See issue #9362.
+ import torch
+except:
+ pass
+
import tvm
-import tvm.relay.testing
+import tvm.testing
from tvm import relay
from tvm.contrib.download import download_testdata
from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt
@@ -31,9 +37,10 @@ def skip_codegen_test():
if not tvm.runtime.enabled("cuda") or not tvm.cuda(0).exist:
print("Skip because CUDA is not enabled.")
return True
- if not tvm.get_global_func("relay.ext.tensorrt", True):
- print("Skip because TensorRT codegen is not available.")
+ if not tensorrt.is_tensorrt_compiler_enabled():
+ print("Skip because TensorRT compiler is not available.")
return True
+ print("TensorRT compiler is available!")
return False
@@ -44,6 +51,7 @@ def skip_runtime_test():
if not tensorrt.is_tensorrt_runtime_enabled():
print("Skip because TensorRT runtime is not available.")
return True
+ print("TensorRT runtime is available!")
return False
@@ -102,12 +110,11 @@ def test_trt_int8():
# compile the model
target = "cuda"
- dev = tvm.cuda(1)
- mod, config = partition_for_tensorrt(mod, params)
- with tvm.transform.PassContext(opt_level=3,
config={"relay.ext.tensorrt.options": config}):
+ dev = tvm.cuda()
+ mod = partition_for_tensorrt(mod, params)
+ with tvm.transform.PassContext(opt_level=3):
lib = relay.build(mod, target=target, params=params)
- dtype = "float32"
gen_module = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
num_cali_int8 = int(os.environ["TENSORRT_NUM_CALI_INT8"])
@@ -146,4 +153,4 @@ def test_trt_int8():
if __name__ == "__main__":
- pytest.main([__file__])
+ tvm.testing.main()