This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 139921e709 [Unity][BYOC][Pass] RunCodegen and TensorRT (#14078)
139921e709 is described below
commit 139921e709ac421806061a9f11469062810ea688
Author: Sunghyun Park <[email protected]>
AuthorDate: Tue Feb 21 19:14:52 2023 -0800
[Unity][BYOC][Pass] RunCodegen and TensorRT (#14078)
This PR introduces the fundamental workflow for BYOC and integrates TensorRT
as a demonstration.
---
cmake/modules/contrib/TensorRT.cmake | 2 +-
include/tvm/ir/module.h | 6 +
include/tvm/relax/transform.h | 110 ++++--
python/tvm/ir/module.py | 35 ++
python/tvm/relax/transform/transform.py | 23 ++
src/ir/module.cc | 12 +
.../backend/contrib/codegen_json/codegen_json.h | 419 +++++++++++++++++++++
src/relax/backend/contrib/tensorrt/codegen.cc | 267 +++++++++++++
src/relax/backend/contrib/utils.h | 127 +++++++
src/relax/transform/run_codegen.cc | 190 ++++++++++
tests/python/relax/test_codegen_tensorrt.py | 124 ++++++
tests/python/relax/test_transform_codegen_pass.py | 260 +++++++++++++
12 files changed, 1553 insertions(+), 22 deletions(-)
diff --git a/cmake/modules/contrib/TensorRT.cmake
b/cmake/modules/contrib/TensorRT.cmake
index 696108b501..a749b6e80f 100644
--- a/cmake/modules/contrib/TensorRT.cmake
+++ b/cmake/modules/contrib/TensorRT.cmake
@@ -23,7 +23,7 @@ include (FindPackageHandleStandardArgs)
if(USE_TENSORRT_CODEGEN)
message(STATUS "Build with TensorRT codegen")
- tvm_file_glob(GLOB COMPILER_TENSORRT_SRCS
src/relay/backend/contrib/tensorrt/*.cc)
+ tvm_file_glob(GLOB COMPILER_TENSORRT_SRCS
src/relay/backend/contrib/tensorrt/*.cc src/relax/backend/contrib/tensorrt/*.cc)
set_source_files_properties(${COMPILER_TENSORRT_SRCS} PROPERTIES
COMPILE_FLAGS "-Wno-deprecated-declarations")
tvm_file_glob(GLOB RUNTIME_TENSORRT_SRCS
src/runtime/contrib/tensorrt/tensorrt_runtime.cc)
set_source_files_properties(${RUNTIME_TENSORRT_SRCS} PROPERTIES
COMPILE_FLAGS "-Wno-deprecated-declarations")
diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h
index fdb44b1188..538ff64ca3 100644
--- a/include/tvm/ir/module.h
+++ b/include/tvm/ir/module.h
@@ -115,6 +115,12 @@ class IRModuleNode : public Object {
return GetAttr<TObjectRef>(attr_key, Optional<TObjectRef>(default_value));
}
+ /*!
+ * \brief Get the metadata attributes.
+ * \returns The additional meta-data attributes
+ */
+ DictAttrs GetAttrs() const { return attrs; }
+
/*!
* \brief Check whether the module has an non-zero integer attr.
*
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 7d9f3d64b0..7d6c93bcde 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -25,6 +25,7 @@
#define TVM_RELAX_TRANSFORM_H_
#include <tvm/ir/transform.h>
+#include <tvm/relax/dataflow_pattern.h>
#include <tvm/relax/expr.h>
namespace tvm {
@@ -67,6 +68,13 @@ TVM_DLL Pass CreateDataflowBlockPass(
const runtime::TypedPackedFunc<DataflowBlock(DataflowBlock, IRModule,
PassContext)>& pass_func,
int opt_level, String name, tvm::Array<String> required, bool traceable =
false);
+/*!
+ * \brief Perform lambda lifting to lift functions from nested into global.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass LambdaLift();
+
/*!
* \brief Transform all dataflow structure to non-dataflow version.
*
@@ -105,27 +113,20 @@ TVM_DLL Pass RewriteDataflowReshape();
TVM_DLL Pass StaticPlanBlockMemory();
/*!
- * \brief Bind params of function of the module to constant tensors.
- *
- * \param func_name The name of the function to bind parameters.
- * \param params The parameters to bind.
+ * \brief Attach global_symbol to Relax functions and TIR Primfuncs for
codegen.
*
* \return The Pass.
*/
-TVM_DLL Pass BindParams(String func_name, Map<String, runtime::NDArray>
params);
+TVM_DLL Pass AttachGlobalSymbol();
/*!
- * \brief Fold constant expressions.
- *
- * \return The Pass.
- */
-TVM_DLL Pass FoldConstant();
-/*!
- * \brief Attach global_symbol to Relax functions and TIR Primfuncs for
codegen.
+ * \brief Transform Relax IR to normal form: transform AST to A-normal form,
and fill the
+ * checked_type_ and shape_ of expressions.
*
* \return The Pass.
*/
-TVM_DLL Pass AttachGlobalSymbol();
+TVM_DLL Pass Normalize();
+
/*!
* \brief Bind params of function of the module to constant tensors.
*
@@ -143,14 +144,6 @@ TVM_DLL Pass BindParams(String func_name, Map<String,
runtime::NDArray> params);
*/
TVM_DLL Pass FoldConstant();
-/*!
- * \brief Transform Relax IR to normal form: transform AST to A-normal form,
and fill the
- * checked_type_ and shape_ of expressions.
- *
- * \return The Pass.
- */
-TVM_DLL Pass Normalize();
-
/*!
* \brief Legalize high-level operator calls in Relax functions to call_tir
* with corresponding low-level TIR PrimFuncs.
@@ -190,6 +183,81 @@ TVM_DLL Pass LegalizeOps(Optional<Map<String, PackedFunc>>
cmap);
*/
TVM_DLL Pass LiftTransformParams();
+/*!
+ * \brief Annotate Op Pattern Kind for TIR functions, which is used in FuseOps.
+ * \note It is an auto-detect pass for "unscheduled prim_funcs", the
op_pattern will be
+ * "opaque" if we can't detect it. Users can manually annotate the attr
`op_pattern`
+ * to prim_func.
+ * \return The Pass.
+ */
+TVM_DLL Pass AnnotateTIROpPattern();
+
+/*!
+ * \brief This pass groups bindings in a dataflow block of Relax functions and
generates a new
+ * grouped Relax function for each group, according to the fusion algorithm
described in the pass
+ * implementation. By grouping bindings into new Relax functions, we
substitute the bindings in the
+ * function being manipulated into function calls to the new grouped function.
+ *
+ * A follow-up pass named "FuseTIR" will generate a TIR PrimFunc for each
grouped function.
+ * \param fuse_opt_level The level of fuse optimization.
+ * -1 indicates that the level will be inferred from pass context.
+ * \return The Pass.
+ */
+TVM_DLL Pass FuseOps(int fuse_opt_level = -1);
+
+/*!
+ * \brief Apply pattern matching to each function in the given module, and
group matched
+ * expressions into a new function. The end result is similar to FuseOps, but
fusion is driven
+ * completely by the provided patterns.
+ *
+ * \param pattern_names The name of each pattern. It becomes the value of the
kComposite attribute
+ * of a fused function after successful matching.
+ * \param patterns The patterns to detect. The order of the patterns
determines the order
+ * of priority in which they are matched. Higher-priority patterns should come
earlier in the list.
+ * \param annotate_codegen If true, wrap each created composite function with
another function,
+ * whose body consists only of a call to the composite function, and annotate
the outer function
+ * with kCodegen and kGlobalSymbol attributes. The kCodegen attribute is set
as the prefix of the
+ * corresponding pattern name. For example, "dnnl" if the pattern name is
"dnnl.conv2d_relu".
+ * This must be True if the created composite functions are intended to be
offloaded to
+ * an external backend without using the MergeCompositeFunctions pass.
+ * \return The Pass.
+ */
+TVM_DLL Pass FuseOpsByPattern(const tvm::Array<runtime::String>& pattern_names,
+ const tvm::Array<DFPattern>& patterns, bool
annotate_codegen = false);
+
+/*!
+ * \brief Group one or multiple composite functions created by
FuseOpsByPattern into a new
+ * function. The new function will be annotated with kCodegen and
GlobalSymbol attributes,
+ * and it is intended to be offloaded to an external backend.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass MergeCompositeFunctions();
+
+/*!
+ * \brief Fuse relax sub-function into a larger TIR function if possible.
+ This pass works together with FuseOps to perform operator fusion.
+
+ * \return The Pass.
+ */
+TVM_DLL Pass FuseTIR();
+
+/*!
+ * \brief Remove unused global relax functions in an IRModule.
+ * \param entry_functions list of entry functions
+ * \return The Pass.
+ */
+TVM_DLL Pass RemoveUnusedFunctions(Array<runtime::String> entry_functions);
+
+/*!
+ * \brief Run codegen.
+ * \param target_options pairs of target name and compilation options
+ * \param entry_functions list of entry functions
+ * \return The Pass.
+ */
+TVM_DLL Pass RunCodegen(Optional<Map<String, Map<String, ObjectRef>>>
target_options,
+ Array<runtime::String> entry_functions);
+
} // namespace transform
} // namespace relax
} // namespace tvm
diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py
index 3daffb2640..6a151d5a89 100644
--- a/python/tvm/ir/module.py
+++ b/python/tvm/ir/module.py
@@ -15,13 +15,18 @@
# specific language governing permissions and limitations
# under the License.
"""IRModule that holds the functions and type definitions."""
+from __future__ import annotations
+
+from typing import Dict, Union
import tvm._ffi
from tvm._ffi.base import string_types
from tvm.runtime import Scriptable
+from tvm.runtime.object import Object
from . import _ffi_api
from . import expr as _expr
from . import type as _ty
+from .attrs import DictAttrs
from .base import Node
@@ -286,6 +291,36 @@ class IRModule(Node, Scriptable):
return _ffi_api.Module_WithAttr(self, attr_key, attr_value)
+ def without_attr(self, attr_key: str) -> "IRModule":
+ """Copy the IRModule and remove an attribute key and its associated
value.
+ Parameters
+ ----------
+ attr_key : str
+ The attribute key.
+ Returns
+ -------
+ mod : IRModule
+ A new copy of the IRModule without the attribute
+ """
+
+ return _ffi_api.Module_WithoutAttr(self, attr_key)
+
+ def with_attrs(self, attr_map: Union[DictAttrs, Dict[str, Object]]) ->
"IRModule":
+ """Copy the IRModule and add the given attribute map to it.
+ Parameters
+ ----------
+ attr_map: Union[DictAttrs, Dict[str, Object]]
+ The attribute map
+ Returns
+ -------
+ mod : IRModule
+ A new copy of the IRModule with the attribute
+ """
+ if isinstance(attr_map, tvm.ir.DictAttrs):
+ attr_map = attr_map._dict()
+
+ return _ffi_api.Module_WithAttrs(self, attr_map)
+
def astext(self, show_meta_data=True, annotate=None):
"""Get the text format of the expression.
diff --git a/python/tvm/relax/transform/transform.py
b/python/tvm/relax/transform/transform.py
index 590059739c..9fb2458dc0 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -188,6 +188,29 @@ def RemoveUnusedFunctions(entry_functions:
Optional[List[str]] = None) -> tvm.ir
return _ffi_api.RemoveUnusedFunctions(entry_functions) # type: ignore
+def RunCodegen(
+ target_options: Optional[dict] = None,
+ entry_functions: Optional[List[str]] = None,
+) -> tvm.ir.transform.Pass:
+ """Produce the runtime::Module with an annotated codegen and global symbol.
+
+ Parameters
+ ----------
+ target_options: Optional[dict]
+ Pairs of a target name and compilation options
+ entry_functions: Optional[List[str]]
+ The set of entry functions to start from.
+
+ Returns
+ -------
+ ret : tvm.transform.Pass
+ The registered pass to run codegen for the annotated functions.
+ """
+ if entry_functions is None:
+ entry_functions = ["main"]
+ return _ffi_api.RunCodegen(target_options, entry_functions) # type: ignore
+
+
def FoldConstant() -> tvm.ir.transform.Pass:
"""Fold constant expressions.
diff --git a/src/ir/module.cc b/src/ir/module.cc
index 4a09bdaaf7..8f23f19d35 100644
--- a/src/ir/module.cc
+++ b/src/ir/module.cc
@@ -431,11 +431,23 @@
TVM_REGISTER_GLOBAL("ir.Module_ImportFromStd").set_body_typed([](IRModule mod, S
mod->ImportFromStd(path);
});
+TVM_REGISTER_GLOBAL("ir.Module_GetAttrs").set_body_typed([](IRModule mod) ->
ObjectRef {
+ return mod->GetAttrs();
+});
+
TVM_REGISTER_GLOBAL("ir.Module_WithAttr")
.set_body_typed([](IRModule mod, String key, ObjectRef value) -> IRModule {
return WithAttr(mod, key, value);
});
+TVM_REGISTER_GLOBAL("ir.Module_WithoutAttr")
+ .set_body_typed([](IRModule mod, String key) -> IRModule { return
WithoutAttr(mod, key); });
+
+TVM_REGISTER_GLOBAL("ir.Module_WithAttrs")
+ .set_body_typed([](IRModule mod, Map<String, ObjectRef> attr_map) ->
IRModule {
+ return WithAttrs(mod, attr_map);
+ });
+
TVM_REGISTER_GLOBAL("ir.Module_GetAttr").set_body_typed([](IRModule mod,
String key) -> ObjectRef {
return mod->GetAttr<ObjectRef>(key);
});
diff --git a/src/relax/backend/contrib/codegen_json/codegen_json.h
b/src/relax/backend/contrib/codegen_json/codegen_json.h
new file mode 100644
index 0000000000..2197998707
--- /dev/null
+++ b/src/relax/backend/contrib/codegen_json/codegen_json.h
@@ -0,0 +1,419 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file relax/backend/contrib/codegen_json/codegen_json.h
+ * \brief Utilities for json codegen and runtime
+ */
+#ifndef TVM_RELAX_BACKEND_CONTRIB_CODEGEN_JSON_CODEGEN_JSON_H_
+#define TVM_RELAX_BACKEND_CONTRIB_CODEGEN_JSON_CODEGEN_JSON_H_
+
+#include <dmlc/any.h>
+#include <dmlc/json.h>
+#include <tvm/node/reflection.h>
+#include <tvm/relax/struct_info.h>
+#include <tvm/tir/op.h>
+
+#include <cstdint>
+#include <limits>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "../../../../runtime/contrib/json/json_node.h"
+#include "../../../../runtime/contrib/json/json_runtime.h"
+#include "../../../transform/utils.h"
+#include "../utils.h"
+
+namespace tvm {
+namespace relax {
+namespace backend {
+namespace contrib {
+
+using namespace tvm::runtime::json;
+
+using ShapeVector = std::vector<std::vector<int64_t>>;
+using TypeVector = std::vector<std::string>;
+using JSONGraphObjectPtr = std::shared_ptr<JSONGraphNode>;
+
+/*!
+ * \brief Helper class to extract all attributes of a certain op and save them
+ * into text format.
+ */
+class OpAttrExtractor : public AttrVisitor {
+ public:
+ explicit OpAttrExtractor(JSONGraphObjectPtr node) : node_(node) {}
+
+ template <typename T = double, typename =
std::enable_if_t<std::is_floating_point<T>::value>>
+ std::string Fp2String(const T value) {
+ std::ostringstream out;
+ out.precision(std::numeric_limits<T>::max_digits10);
+ out << value;
+ return out.str();
+ }
+
+ void SetNodeAttr(const char* key, const std::vector<std::string>& value) {
+ std::vector<dmlc::any> attr;
+ attr.emplace_back(value);
+ node_->SetAttr(key, attr);
+ }
+
+ void Visit(const char* key, double* value) final { SetNodeAttr(key,
{Fp2String(*value)}); }
+
+ void Visit(const char* key, int64_t* value) final { SetNodeAttr(key,
{std::to_string(*value)}); }
+
+ void Visit(const char* key, uint64_t* value) final { SetNodeAttr(key,
{std::to_string(*value)}); }
+
+ void Visit(const char* key, int* value) final { SetNodeAttr(key,
{std::to_string(*value)}); }
+
+ void Visit(const char* key, bool* value) final { SetNodeAttr(key,
{std::to_string(*value)}); }
+
+ void Visit(const char* key, std::string* value) final { SetNodeAttr(key,
{*value}); }
+
+ void Visit(const char* key, DataType* value) final {
+ if (!value->is_void()) {
+ SetNodeAttr(key, {runtime::DLDataType2String(*value)});
+ } else {
+ SetNodeAttr(key, {""});
+ }
+ }
+
+ void Visit(const char* key, runtime::ObjectRef* value) final {
+ if (const auto* an = (*value).as<ArrayNode>()) {
+ std::vector<std::string> attr;
+ for (size_t i = 0; i < an->size(); ++i) {
+ if (const auto* im = (*an)[i].as<IntImmNode>()) {
+ attr.push_back(std::to_string(im->value));
+ } else if (const auto* fm = (*an)[i].as<FloatImmNode>()) {
+ attr.push_back(Fp2String(fm->value));
+ } else if (const auto* str = (*an)[i].as<StringObj>()) {
+ String s = GetRef<String>(str);
+ attr.push_back(s);
+ } else {
+ LOG(FATAL) << "Not supported type: " << (*an)[i]->GetTypeKey();
+ }
+ }
+ SetNodeAttr(key, attr);
+ } else if (!(*value).defined()) { // Skip NullValue
+ SetNodeAttr(key, std::vector<std::string>{""});
+ } else if (const auto* im = (*value).as<IntImmNode>()) {
+ SetNodeAttr(key, std::vector<std::string>{std::to_string(im->value)});
+ } else if (const auto* fm = (*value).as<FloatImmNode>()) {
+ SetNodeAttr(key, std::vector<std::string>{Fp2String(fm->value)});
+ } else if (const auto* str = (*value).as<StringObj>()) {
+ String s = GetRef<String>(str);
+ SetNodeAttr(key, std::vector<std::string>{s});
+ } else {
+ LOG(FATAL) << "Not yet supported type: " << (*value)->GetTypeKey() << ":
" << *value;
+ }
+ }
+
+ void Visit(const char* key, runtime::NDArray* value) final {
+ LOG(FATAL) << "NDArray is not allowed in op attribute";
+ }
+
+ void Visit(const char* key, void** value) final {
+ LOG(FATAL) << "void pointer is not allowed in op attribute";
+ }
+
+ void Extract(Object* node) {
+ if (node) {
+ reflection_->VisitAttrs(node, this);
+ }
+ }
+
+ private:
+ JSONGraphObjectPtr node_;
+ ReflectionVTable* reflection_ = ReflectionVTable::Global();
+};
+
+using NodeEntries = std::vector<JSONGraphNodeEntry>;
+
+/*! \brief Serialize a Relax expression to JSON. */
+class JSONSerializer : public relax::MemoizedExprTranslator<NodeEntries> {
+ public:
+ using MemoizedExprTranslator<NodeEntries>::VisitExpr_;
+ using MemoizedExprTranslator<NodeEntries>::VisitBinding_;
+
+ /*!
+ * \brief Constructor
+ * \param constant_names The names of all constants in the original module.
+ */
+ explicit JSONSerializer(const Map<Constant, String>& constant_names)
+ : constant_names_(constant_names) {}
+
+ void serialize(Function func) {
+ // First we convert all the parameters into input nodes.
+ for (const auto& param : func->params) {
+ auto node_ptr = std::make_shared<JSONGraphNode>(param->name_hint(),
"input" /* op_type_ */);
+ memo_[param] = AddNode(node_ptr, param);
+ }
+ heads_ = VisitExpr(func->body);
+ }
+
+ /*!\brief Return the required constants. */
+ Array<String> GetConstantNames() const { return constants_used_; }
+
+ /*!\brief Return the generated json. */
+ std::string GetJSON() {
+ std::ostringstream os;
+ dmlc::JSONWriter writer(&os);
+ Save(&writer);
+ return os.str();
+ }
+
+ protected:
+ /*!
+ * \brief Add a node to graph.
+ *
+ * \param node A graph node. It is a shared pointer. Some attributes of it
+ * will be added, i.e. shape and type. These attributes are attached
to
+ * the JSON graph in the end.
+ * \param expr The relax expression.
+ * \return A list of graph entry nodes. It the relax expr is a tuple type, we
+ * will flatten it.
+ */
+ NodeEntries AddNode(JSONGraphObjectPtr node, const Expr& expr) {
+ auto struct_info = GetStructInfo(expr);
+ auto node_id = nodes_.size();
+ nodes_.push_back(node);
+ NodeEntries ret;
+ ShapeVector shape;
+ TypeVector dtype;
+
+ // Flatten tuple node.
+ if (const auto* tuple_sinfo = struct_info.as<TupleStructInfoNode>()) {
+ for (size_t i = 0; i < tuple_sinfo->fields.size(); ++i) {
+ const auto* tensor_sinfo =
tuple_sinfo->fields[i].as<TensorStructInfoNode>();
+ ICHECK(tensor_sinfo) << "Expect TensorStructInfo, but received: ."
+ << tuple_sinfo->fields[i]->GetTypeKey();
+ ICHECK(tensor_sinfo->shape.defined()) << "Expect shape to be defined.";
+ ShapeExpr output_shape =
Downcast<ShapeExpr>(tensor_sinfo->shape.value());
+ ret.push_back(JSONGraphNodeEntry(node_id, i));
+ shape.emplace_back(GetIntShape(output_shape->values));
+ dtype.emplace_back(DType2String(tensor_sinfo->dtype));
+ }
+ node->SetNumOutput(tuple_sinfo->fields.size());
+ } else {
+ const auto* tensor_sinfo = struct_info.as<TensorStructInfoNode>();
+ ICHECK(tensor_sinfo) << "Expect TensorStructInfo, but received: "
+ << struct_info->GetTypeKey();
+ ICHECK(tensor_sinfo->shape.defined()) << "Expect shape to be defined.";
+ ShapeExpr output_shape =
Downcast<ShapeExpr>(tensor_sinfo->shape.value());
+
+ shape.emplace_back(GetIntShape(output_shape->values));
+ dtype.emplace_back(DType2String(tensor_sinfo->dtype));
+ ret.push_back(JSONGraphNodeEntry(node_id, 0));
+ }
+ std::vector<dmlc::any> shape_attrs;
+ shape_attrs.emplace_back(shape);
+ node->SetAttr("shape", shape_attrs);
+
+ std::vector<dmlc::any> type_attrs;
+ type_attrs.emplace_back(dtype);
+ node->SetAttr("dtype", type_attrs);
+ return ret;
+ }
+
+ void SetCallNodeAttribute(JSONGraphObjectPtr node, const CallNode* cn) {
+ if (cn->op.as<OpNode>()) {
+ OpAttrExtractor extractor(node);
+ const Object* call_attr = cn->attrs.get();
+ extractor.Extract(const_cast<Object*>(call_attr));
+ } else if (const auto* fn = cn->op.as<FunctionNode>()) {
+ ICHECK(false);
+ auto pattern = fn->GetAttr<String>(attr::kPartitionedFromPattern);
+ ICHECK(pattern.defined());
+ std::vector<std::string> values;
+ values.push_back(pattern.value());
+ std::vector<dmlc::any> attr;
+ attr.emplace_back(values);
+ node->SetAttr("PartitionedFromPattern", attr);
+ }
+ }
+
+ NodeEntries VisitBinding_(const MatchCastNode* binding) {
+ LOG(FATAL) << "JSON runtime currently doesn't match cast\n";
+ return {};
+ }
+
+ NodeEntries VisitBinding(const Binding& binding) {
+ NodeEntries nodes;
+ if (const auto* node = binding.as<VarBindingNode>()) {
+ auto from_b = VisitBinding_(node);
+ nodes.insert(nodes.end(), from_b.begin(), from_b.end());
+ } else if (const auto* node = binding.as<MatchCastNode>()) {
+ auto from_b = VisitBinding_(node);
+ nodes.insert(nodes.end(), from_b.begin(), from_b.end());
+ } else {
+ LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey();
+ }
+ return nodes;
+ }
+
+ NodeEntries VisitBindingBlock(const BindingBlock& block) {
+ NodeEntries nodes;
+ if (const auto* node = block.as<DataflowBlockNode>()) {
+ auto from_bb = VisitBindingBlock_(node);
+ nodes.insert(nodes.end(), from_bb.begin(), from_bb.end());
+ } else if (const auto* node = block.as<BindingBlockNode>()) {
+ auto from_bb = VisitBindingBlock_(node);
+ nodes.insert(nodes.end(), from_bb.begin(), from_bb.end());
+ } else {
+ LOG(FATAL) << "TypeError: Invalid type: " << block->GetTypeKey();
+ }
+ return nodes;
+ }
+
+ NodeEntries VisitBindingBlock_(const BindingBlockNode* block) {
+ NodeEntries nodes;
+ for (Binding binding : block->bindings) {
+ auto from_b = VisitBinding(binding);
+ nodes.insert(nodes.end(), from_b.begin(), from_b.end());
+ }
+ return nodes;
+ }
+
+ NodeEntries VisitBindingBlock_(const DataflowBlockNode* block) {
+ NodeEntries nodes;
+ for (Binding binding : block->bindings) {
+ auto from_b = VisitBinding(binding);
+ nodes.insert(nodes.end(), from_b.begin(), from_b.end());
+ }
+ return nodes;
+ }
+
+ NodeEntries VisitExpr_(const SeqExprNode* op) {
+ NodeEntries nodes;
+ for (BindingBlock block : op->blocks) {
+ VisitBindingBlock(block);
+ }
+ auto from_body = VisitExpr(op->body);
+ nodes.insert(nodes.end(), from_body.begin(), from_body.end());
+ return nodes;
+ }
+
+ NodeEntries VisitExprDefault_(const Object* op) {
+ LOG(FATAL) << "JSON runtime currently doesn't support " <<
op->GetTypeKey();
+ return {};
+ }
+
+ NodeEntries VisitExpr_(const ConstantNode* cn) {
+ auto name = constant_names_.find(GetRef<Constant>(cn));
+ ICHECK(name != constant_names_.end())
+ << "Cannot find the name of the constant: " << GetRef<Constant>(cn);
+ constants_used_.push_back((*name).second);
+ auto node = std::make_shared<JSONGraphNode>((*name).second, "const" /*
op_type_ */);
+ return AddNode(node, GetRef<Expr>(cn));
+ }
+
+ NodeEntries VisitExpr_(const TupleNode* tn) {
+ NodeEntries fields;
+ for (const auto& field : tn->fields) {
+ auto ref = VisitExpr(field);
+ fields.insert(fields.end(), ref.begin(), ref.end());
+ }
+ return fields;
+ }
+
+ NodeEntries VisitExpr_(const CallNode* cn) {
+ Expr expr = GetRef<Expr>(cn);
+ std::string name;
+ if (const auto* op_node = cn->op.as<OpNode>()) {
+ name = op_node->name;
+ } else if (const auto* fn = cn->op.as<FunctionNode>()) {
+ auto comp = fn->GetAttr<String>(attr::kComposite);
+ ICHECK(comp.defined()) << "JSON runtime only supports composite
functions.";
+ name = comp.value();
+ } else {
+ LOG(FATAL) << "JSON runtime does not support calls to " <<
cn->op->GetTypeKey();
+ }
+
+ // TODO(@sunggg): Revisit when we have op naming convention.
+ // Currently, simply remove "relax." prefix to make it work.
+ name = std::string("tensorrt.") + name.substr(6);
+
+ NodeEntries inputs;
+ for (const auto& arg : cn->args) {
+ auto res = VisitExpr(arg);
+ inputs.insert(inputs.end(), res.begin(), res.end());
+ }
+ auto node = std::make_shared<JSONGraphNode>(name, /* name_ */
+ "kernel", /* op_type_ */
+ inputs, 1 /* num_outputs_ */);
+ SetCallNodeAttribute(node, cn);
+ return AddNode(node, GetRef<Expr>(cn));
+ }
+
+ NodeEntries VisitExpr_(const TupleGetItemNode* gtn) {
+ auto vtuple = VisitExpr(gtn->tuple);
+ return {vtuple[gtn->index]};
+ }
+
+ NodeEntries VisitExpr_(const FunctionNode* fn) {
+ ICHECK(fn->GetAttr<String>(attr::kComposite).defined())
+ << "JSON runtime only supports composite functions";
+
+ // FunctionNode should be handled by the caller.
+ return {};
+ }
+
+ /*!
+ * \brief Save to JSON graph
+ *
+ * \param writer A json writer
+ */
+ void Save(dmlc::JSONWriter* writer) {
+ std::vector<size_t> arg_nodes;
+ for (size_t i = 0; i < nodes_.size(); ++i) {
+ auto node = nodes_[i];
+ if (node->IsLeaf()) {
+ arg_nodes.push_back(i);
+ }
+ }
+ size_t num_entry = 0;
+ std::vector<size_t> node_row_ptr{0};
+ for (auto node : nodes_) {
+ num_entry += node->GetNumOutput();
+ node_row_ptr.push_back(num_entry);
+ }
+ writer->BeginObject();
+ writer->WriteObjectKeyValue("nodes", nodes_);
+ writer->WriteObjectKeyValue("arg_nodes", arg_nodes);
+ writer->WriteObjectKeyValue("heads", heads_);
+ writer->WriteObjectKeyValue("node_row_ptr", node_row_ptr);
+ writer->EndObject();
+ }
+
+ private:
+ /*! \brief JSON graph nodes. */
+ std::vector<JSONGraphObjectPtr> nodes_;
+ /*! \brief Output of the JSON graph. */
+ NodeEntries heads_;
+ /*! \brief The list of required constants, ordered. */
+ Array<String> constants_used_;
+ /*! \brief The names of all constants in the original module. */
+ const Map<Constant, String>& constant_names_;
+};
+
+} // namespace contrib
+} // namespace backend
+} // namespace relax
+} // namespace tvm
+#endif // TVM_RELAX_BACKEND_CONTRIB_CODEGEN_JSON_CODEGEN_JSON_H_
diff --git a/src/relax/backend/contrib/tensorrt/codegen.cc
b/src/relax/backend/contrib/tensorrt/codegen.cc
new file mode 100644
index 0000000000..5ce6bf5e7d
--- /dev/null
+++ b/src/relax/backend/contrib/tensorrt/codegen.cc
@@ -0,0 +1,267 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/backend/contrib/tensorrt/codegen.cc
+ * \brief Implementation of the TensorRT JSON serializer.
+ */
+#include <tvm/ir/module.h>
+// TODO(sunggg): add operator attribute when it's ready
+// #include <tvm/relax/attrs/nn.h>
+#include <tvm/relax/type.h>
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "../../../transform/utils.h"
+#include "../codegen_json/codegen_json.h"
+#include "../utils.h"
+
+#if TVM_GRAPH_EXECUTOR_TENSORRT
+#include "NvInfer.h"
+#endif
+
+namespace tvm {
+namespace relax {
+namespace contrib {
+
+/*! \brief Attributes to store the compiler options for TensorRT. */
+struct TensorRTCompilerConfigNode : public
tvm::AttrsNode<TensorRTCompilerConfigNode> {
+ Array<Integer> tensorrt_version;
+ bool use_implicit_batch;
+ size_t max_workspace_size;
+ bool remove_no_mac_subgraphs;
+ bool use_fp16;
+ bool use_uint8;
+
+ TVM_DECLARE_ATTRS(TensorRTCompilerConfigNode,
"relax.ext.attrs.TensorRTCompilerConfigNode") {
+ TVM_ATTR_FIELD(tensorrt_version)
+ .describe("TensorRT version as (major, minor, patch).")
+ .set_default(Array<Integer>({6, 0, 1}));
+ TVM_ATTR_FIELD(use_implicit_batch).set_default(true);
+ TVM_ATTR_FIELD(max_workspace_size).set_default(size_t(1) << 30);
+ TVM_ATTR_FIELD(remove_no_mac_subgraphs).set_default(false);
+ TVM_ATTR_FIELD(use_fp16).set_default(false);
+ TVM_ATTR_FIELD(use_uint8).set_default(false);
+ }
+};
+
+class TensorRTCompilerConfig : public Attrs {
+ public:
+ TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TensorRTCompilerConfig, Attrs,
+ TensorRTCompilerConfigNode);
+};
+
+TVM_REGISTER_NODE_TYPE(TensorRTCompilerConfigNode);
+TVM_REGISTER_PASS_CONFIG_OPTION("relax.ext.tensorrt.options",
TensorRTCompilerConfig);
+
+using JSONGraphNode = tvm::runtime::json::JSONGraphNode;
+using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry;
+using JSONGraphObjectPtr = backend::contrib::JSONGraphObjectPtr;
+using OpAttrExtractor = backend::contrib::OpAttrExtractor;
+using JSONSerializer = backend::contrib::JSONSerializer;
+
+class TensorRTJSONSerializer;
+
+/*!
+ * \brief Collect the constants and attributes from all operator calls in the
body
+ * of a "Composite" function.
+ */
+class CollectFromCompositeFunctionBody : public ExprVisitor {
+ public:
+ explicit CollectFromCompositeFunctionBody(TensorRTJSONSerializer* serializer)
+ : serializer_(serializer), node_(std::make_shared<JSONGraphNode>()) {}
+
+ void VisitExpr_(const ConstantNode* constant_node) final;
+ void VisitExpr_(const CallNode* call_node) final;
+
+ void SetGenericAttributes(const CallNode* call_node) {
+ OpAttrExtractor extractor(node_);
+ const Object* attr_obj = call_node->attrs.get();
+ extractor.Extract(const_cast<Object*>(attr_obj));
+ }
+
+ TensorRTJSONSerializer* serializer_;
+ /*! \brief Accumulated translated arguments. */
+ std::vector<JSONGraphNodeEntry> args_;
+ /*!
+ * \brief Temporary node into which we'll accumulate attributes. Ideally
this would be the
+ * final JSONGraphNode however we don't yet know how many inputs that will
have.
+ */
+ JSONGraphObjectPtr node_;
+};
+
+/*!
+ * \brief Generates an TensorRTModule from a relax expression by serializing
the expression to a
+ * json representation. TensorRT is not required here because use of TensorRT
APIs is deferred until
+ * runtime.
+ */
+class TensorRTJSONSerializer : public JSONSerializer {
+ public:
+ explicit TensorRTJSONSerializer(Map<Constant, String> constant_names,
Map<Var, Expr> bindings)
+ : JSONSerializer(constant_names), bindings_(bindings) {}
+
+ using JSONSerializer::VisitExpr_;
+
+ std::vector<JSONGraphNodeEntry> VisitExpr_(const CallNode* call_node) final {
+ // The call must be to an inline "Composite" function
+ const auto* fn_var = call_node->op.as<VarNode>();
+ ICHECK(fn_var);
+ const auto fn = Downcast<Function>(bindings_[GetRef<Var>(fn_var)]);
+
+ auto opt_composite = fn->GetAttr<String>(attr::kComposite);
+ ICHECK(opt_composite.defined());
+ std::string name = opt_composite.value();
+
+ // Collect the constants and attributes of all operator calls inside the
composite body.
+ CollectFromCompositeFunctionBody collector(this);
+ collector.VisitExpr(fn->body);
+
+ // Capture the args to the "Composite" function as inputs for this node.
+ std::vector<JSONGraphNodeEntry> inputs;
+ for (const auto& arg : call_node->args) {
+ auto res = VisitExpr(arg);
+ inputs.insert(inputs.end(), res.begin(), res.end());
+ }
+
+ // Capture constants from the composite function body as additional inputs
for this node.
+ for (const auto& node : collector.args_) {
+ inputs.emplace_back(node);
+ }
+
+ // Create the final node.
+ auto node = std::make_shared<JSONGraphNode>(name,
+ /*op_type=*/"kernel", inputs,
+ /*num_output=*/1);
+
+ // Transfer attributes from the collector's node to the final node.
+ node->CaptureAttrs(*collector.node_);
+
+ // Capture global settings on the JSON node.
+ SaveGlobalAttributes(node);
+
+ VLOG(1) << name << " has " << node->GetInputs().size() << " inputs";
+
+ return AddNode(node, GetRef<Expr>(call_node));
+ }
+
+ static void SaveGlobalAttributes(std::shared_ptr<JSONGraphNode> node) {
+ auto ctx = transform::PassContext::Current();
+ auto cfg =
ctx->GetConfig<TensorRTCompilerConfig>("relax.ext.tensorrt.options");
+ if (!cfg.defined()) {
+ cfg = AttrsWithDefaultValues<TensorRTCompilerConfig>();
+ }
+ ICHECK_EQ(cfg.value()->tensorrt_version.size(), 3);
+ std::vector<std::string> tensorrt_version = {
+ std::to_string(cfg.value()->tensorrt_version[0].IntValue()),
+ std::to_string(cfg.value()->tensorrt_version[1].IntValue()),
+ std::to_string(cfg.value()->tensorrt_version[2].IntValue())};
+ std::vector<std::string> use_implicit_batch =
{std::to_string(cfg.value()->use_implicit_batch)};
+ std::vector<std::string> max_workspace_size =
{std::to_string(cfg.value()->max_workspace_size)};
+ std::vector<std::string> use_fp16 =
{std::to_string(cfg.value()->use_fp16)};
+ std::vector<std::string> use_uint8 =
{std::to_string(cfg.value()->use_uint8)};
+ std::vector<dmlc::any> tensorrt_version_attr, use_implicit_batch_attr,
max_workspace_size_attr,
+ use_fp16_attr, use_uint8_attr;
+ tensorrt_version_attr.emplace_back(tensorrt_version);
+ use_implicit_batch_attr.emplace_back(use_implicit_batch);
+ max_workspace_size_attr.emplace_back(max_workspace_size);
+ use_fp16_attr.emplace_back(use_fp16);
+ use_uint8_attr.emplace_back(use_uint8);
+ node->SetAttr("tensorrt_version", tensorrt_version_attr);
+ node->SetAttr("use_implicit_batch", use_implicit_batch_attr);
+ node->SetAttr("max_workspace_size", max_workspace_size_attr);
+ node->SetAttr("use_fp16", use_fp16_attr);
+ node->SetAttr("use_uint8", use_uint8_attr);
+ }
+
+ private:
+ /*! \brief The bindings to look up composite functions. */
+ Map<Var, Expr> bindings_;
+};
+
+void CollectFromCompositeFunctionBody::VisitExpr_(const ConstantNode*
constant_node) {
+ for (const auto& entry :
serializer_->VisitExpr(GetRef<Constant>(constant_node))) {
+ args_.emplace_back(entry);
+ }
+}
+
// Capture this call's op attributes into the temporary node, then recurse into its arguments
// (which may contain the constants gathered by the ConstantNode overload above).
void CollectFromCompositeFunctionBody::VisitExpr_(const CallNode* call_node) {
  SetGenericAttributes(call_node);
  ExprVisitor::VisitExpr_(call_node);
}
+
+/*!
+ * \brief Create runtime modules for TensorRT.
+ * \param functions The extern functions to be compiled via TensorRT
+ * \return Runtime modules.
+ */
+Array<runtime::Module> TensorRTCompiler(Array<Function> functions,
+ Map<String, ObjectRef> /*unused*/,
+ Map<Constant, String> constant_names) {
+ Array<runtime::Module> compiled_functions;
+ for (const auto& func : functions) {
+ VLOG(1) << "TensorRT partition:" << std::endl << func;
+ TensorRTJSONSerializer serializer(constant_names, AnalyzeVar2Value(func));
+ serializer.serialize(func);
+ std::string graph_json = serializer.GetJSON();
+ VLOG(1) << "TensorRT JSON:" << std::endl << graph_json;
+ auto constant_names = serializer.GetConstantNames();
+ const auto* pf = runtime::Registry::Get("runtime.tensorrt_runtime_create");
+ ICHECK(pf != nullptr) << "Cannot find TensorRT runtime module create
function.";
+ std::string func_name = GetExtSymbol(func);
+ VLOG(1) << "Creating tensorrt runtime::Module for '" << func_name << "'";
+ compiled_functions.push_back((*pf)(func_name, graph_json, constant_names));
+ }
+ return compiled_functions;
+}
+
+TVM_REGISTER_GLOBAL("relax.ext.tensorrt").set_body_typed(TensorRTCompiler);
+
/*!
 * \brief Check whether TensorRT graph executor is enabled.
 * \return True if enabled, False if not.
 */
inline constexpr bool IsTensorRTRuntimeEnabled() {
#if TVM_GRAPH_EXECUTOR_TENSORRT
  constexpr bool enabled = true;
#else
  constexpr bool enabled = false;
#endif  // TVM_GRAPH_EXECUTOR_TENSORRT
  return enabled;
}
+
+/*!
+ * \brief Get TensorRT version that TVM is built against.
+ * \return Array of three integers for major, minor, and patch, or empty array
if TensorRT graph
+ * runtime is not enabled.
+ */
+Array<Integer> GetTensorRTVersion() {
+#if TVM_GRAPH_EXECUTOR_TENSORRT
+ return {Integer(NV_TENSORRT_MAJOR), Integer(NV_TENSORRT_MINOR),
Integer(NV_TENSORRT_PATCH)};
+#else
+ return {};
+#endif // TVM_GRAPH_EXECUTOR_TENSORRT
+}
+
+TVM_REGISTER_GLOBAL("relax.is_tensorrt_runtime_enabled").set_body_typed(IsTensorRTRuntimeEnabled);
+TVM_REGISTER_GLOBAL("relax.get_tensorrt_version").set_body_typed(GetTensorRTVersion);
+
+} // namespace contrib
+} // namespace relax
+} // namespace tvm
diff --git a/src/relax/backend/contrib/utils.h
b/src/relax/backend/contrib/utils.h
new file mode 100644
index 0000000000..4190ad66b6
--- /dev/null
+++ b/src/relax/backend/contrib/utils.h
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file relax/backend/contrib/utils.h
+ * \brief Utils function for backend
+ */
+#ifndef TVM_RELAX_BACKEND_CONTRIB_UTILS_H_
+#define TVM_RELAX_BACKEND_CONTRIB_UTILS_H_
+
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr.h>
+
+#include <string>
+#include <vector>
+
+#include "../../transform/utils.h"
+
+namespace tvm {
+namespace relax {
+namespace backend {
+
/*!
 * \brief Get the Packed Func
 *
 * \param func_name The name under which the function was registered globally.
 * \return const PackedFunc* pointing at the registered function, or nullptr if no function
 * is registered under that name.
 */
inline const PackedFunc* GetPackedFunc(const std::string& func_name) {
  return tvm::runtime::Registry::Get(func_name);
}
+
+/*!
+ * \brief Extract shape from an IndexExpr array to std::vector<int64_t>
+ *
+ * \param shape The shape in Array
+ * \return The converted shape in std::vector<int64_t>
+ */
+
+inline std::vector<int64_t> GetIntShape(const Array<PrimExpr>& shape) {
+ std::vector<int64_t> ret;
+ for (const auto& dim : shape) {
+ const int64_t* pval = tir::as_const_int(dim);
+ ret.push_back(pval ? *pval : -1);
+ }
+ return ret;
+}
+
/*!
 * \brief Convert type to string
 *
 * \param dtype The data type to convert, e.g. float32 -> "float32".
 * \return std::string string format of type. Aborts via LOG(FATAL) on an unknown type code.
 */
inline std::string DType2String(const tvm::DataType dtype) {
  std::ostringstream os;
  if (dtype.is_float()) {
    os << "float";
  } else if (dtype.is_int()) {
    os << "int";
  } else if (dtype.is_uint()) {
    os << "uint";
  } else if (dtype.is_bfloat16()) {
    os << "bfloat";
  } else if ((*GetPackedFunc("runtime._datatype_get_type_registered"))(dtype.code())) {
    // Custom datatype registered at runtime; resolve its printable name via the registry.
    os << "custom["
       << (*GetPackedFunc("runtime._datatype_get_type_name"))(dtype.code()).operator std::string()
       << "]";
  } else {
    LOG(FATAL) << "Unknown type with code " << static_cast<unsigned>(dtype.code());
  }
  // Append the bit width, e.g. "float" + "32".
  os << dtype.bits();
  return os.str();
}
+
+/*!
+ * \brief Check if a call node is calling an op with the given name
+ * \param call The call node whose callee we want to check
+ * \param op_name The name of the op
+ * \return true if the callee op matches with the op name
+ */
+inline bool IsOp(const CallNode* call, const std::string& op_name) {
+ const auto* op_node = call->op.as<OpNode>();
+ if (!op_node) return false;
+ Op op = GetRef<Op>(op_node);
+ return op == Op::Get(op_name);
+}
+
/*!
 * \brief Return a call node within the function which calls an op with the given name
 * The function must contain exactly one call to such op.
 * \param f The function to look for an op.
 * \param op_name The name of the op
 * \return A call node which calls an op with the given name. Aborts via LOG(FATAL) when no
 * such call exists.
 * NOTE(review): the first matching binding is returned; the "exactly one" precondition stated
 * above is not actually enforced — confirm whether duplicates should be an error.
 */
inline const CallNode* GetOpInFunction(Function f, const std::string& op_name) {
  // Scan every var binding in the function body for a call to `op_name`.
  auto local_bindings = AnalyzeVar2Value(f);
  for (const auto& entry : local_bindings) {
    if (auto call = entry.second.as<CallNode>(); call && backend::IsOp(call, op_name)) {
      return call;
    }
  }
  LOG(FATAL) << op_name << " not found in the function:\n" << f;
  return nullptr;  // unreachable; silences missing-return warnings
}
+
+} // namespace backend
+} // namespace relax
+} // namespace tvm
+
+#endif // TVM_RELAX_BACKEND_CONTRIB_UTILS_H_
diff --git a/src/relax/transform/run_codegen.cc
b/src/relax/transform/run_codegen.cc
new file mode 100644
index 0000000000..114b7d2a34
--- /dev/null
+++ b/src/relax/transform/run_codegen.cc
@@ -0,0 +1,190 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/transform/run_codegen.cc
+ * \brief Run codegen for annotated relax functions.
+ */
+
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr_functor.h>
+
+#include <iostream>
+
+#include "utils.h"
+
+namespace tvm {
+namespace relax {
+
+class CodeGenRunner : ExprMutator {
+ public:
+ using OptionMap = Map<String, ObjectRef>;
+
+ explicit CodeGenRunner(IRModule mod) : ExprMutator(mod) {}
+
+ IRModule Run(Optional<Map<String, OptionMap>> target_options, Array<String>
entry_functions) {
+ IRModule mod = builder_->GetContextIRModule();
+ for (const String& entry_func_name : entry_functions) {
+ auto entry_func = mod->Lookup(entry_func_name);
+ auto gvar = mod->GetGlobalVar(entry_func_name);
+ builder_->UpdateFunction(gvar,
Downcast<BaseFunc>(VisitExpr(entry_func)));
+ }
+
+ auto ext_mods = InvokeCodegen(mod, target_options.value_or({}));
+ auto out_mod = builder_->GetContextIRModule();
+
+ if (ext_mods.size()) {
+ out_mod = WithAttr(out_mod, tvm::attr::kExternalMods,
std::move(ext_mods));
+ }
+
+ if (constant_names.size()) {
+ // Some backends (e.g. TensorRT) expect constants to be passed when they
are instantiated
+ Map<String, runtime::NDArray> constants;
+ for (const auto& [constant, name] : constant_names) {
+ ICHECK(!constants.count(name)) << "More than one constant with the
name " << name;
+ constants.Set(name, constant->data);
+ }
+ out_mod = WithAttr(out_mod, tvm::attr::kConstNameToConstant,
std::move(constants));
+ }
+
+ // TODO(@tvm-team): Implicit pass dependency. Revisit when we have a
better way to handle this.
+ return RemoveUnusedFunctions(out_mod, entry_functions);
+ }
+
+ using ExprMutator::VisitExpr_;
+
+ Expr VisitExpr_(const CallNode* call_node) override {
+ auto call = Downcast<Call>(ExprMutator::VisitExpr_(call_node));
+ if (auto const* gvar_node = call_node->op.as<GlobalVarNode>()) {
+ const GlobalVar gvar = GetRef<GlobalVar>(gvar_node);
+
+ auto create_call_tir = [call_node, this](Expr extern_func, StructInfo
ret_struct_info) {
+ Array<Expr> new_args({extern_func});
+ new_args.push_back(Tuple(call_node->args.Map([this](Expr arg) { return
VisitExpr(arg); })));
+
+ static const Op& call_op = Op::Get("relax.call_tir");
+
+ return Call(call_op, new_args, tvm::Attrs(), {ret_struct_info});
+ };
+
+ if (auto it = extern_funcs_.find(gvar_node); it != extern_funcs_.end()) {
+ return create_call_tir(it->second.first, it->second.second);
+ } else {
+ // TODO(@sunggg): Is there any better way to get this func?
+ Function func =
Downcast<Function>(builder_->GetContextIRModule()->Lookup(gvar));
+ Expr new_func = VisitExpr(func);
+
+ if (new_func->IsInstance<ExternFuncNode>()) {
+ extern_funcs_[gvar_node] = {new_func, func->ret_struct_info};
+ // Remove the global symbol and codegen attributes from the function
so that it can be
+ // removed the module.
+ static const runtime::PackedFunc* RemoveFuncAttrFunc =
+ runtime::Registry::Get("ir.BaseFuncWithoutAttr");
+ ICHECK(RemoveFuncAttrFunc);
+ func = (*RemoveFuncAttrFunc)(func, tvm::attr::kGlobalSymbol);
+ func = (*RemoveFuncAttrFunc)(func, attr::kCodegen);
+ builder_->UpdateFunction(gvar, func);
+ return create_call_tir(new_func, func->ret_struct_info);
+ }
+ }
+ }
+ Array<Expr> new_args;
+ for (const auto& arg : call_node->args) {
+ new_args.push_back(VisitExpr(arg));
+ }
+
+ return Call(call_node->op, new_args, call_node->attrs,
call_node->sinfo_args, call_node->span);
+ }
+
+ Expr VisitExpr_(const FunctionNode* func_node) override {
+ Function func = GetRef<Function>(func_node);
+ auto opt_codegen = func->GetAttr<String>(attr::kCodegen);
+ if (opt_codegen) {
+ auto ext_symbol = GetExtSymbol(func);
+ size_t count = 0;
+ PostOrderVisit(func->body, [=, &count](Expr e) {
+ if (e->IsInstance<ConstantNode>()) {
+ // Make sure to pick a unique name
+ auto name = ext_symbol + "_" + opt_codegen.value() + "_const_" +
std::to_string(count++);
+ auto constant = Downcast<Constant>(e);
+ constant_names.Set(constant, name);
+ }
+ });
+ return ExternFunc(GetExtSymbol(func));
+ } else {
+ return ExprMutator::VisitExpr_(func_node);
+ }
+ }
+
+ private:
+ Array<runtime::Module> InvokeCodegen(IRModule mod, Map<String, OptionMap>
target_options) {
+ std::unordered_map<std::string, Array<Function>> target_functions;
+
+ for (const auto& entry : mod->functions) {
+ PostOrderVisit(entry.second, [&target_functions](Expr e) {
+ if (e->IsInstance<FunctionNode>()) {
+ auto f = Downcast<Function>(e);
+ if (auto target_opt = f->GetAttr<String>(attr::kCodegen)) {
+ String target = target_opt.value();
+ target_functions[target].push_back(f);
+ }
+ }
+ });
+ }
+
+ Array<runtime::Module> ext_mods;
+
+ for (const auto& [target, functions] : target_functions) {
+ OptionMap options = target_options.Get(target).value_or({});
+ // Start the codegen process.
+ // Get the codegen with its ffi key.
+ String codegen_name = "relax.ext." + target;
+ auto codegen = runtime::Registry::Get(codegen_name);
+ ICHECK(codegen) << "Codegen is not found: " << codegen_name << "\n";
+
+ Array<runtime::Module> compiled_functions = (*codegen)(functions,
options, constant_names);
+ ext_mods.insert(ext_mods.end(), compiled_functions.begin(),
compiled_functions.end());
+ }
+
+ return ext_mods;
+ }
+
+ /*! \brief The names of all constants in the original module. */
+ Map<Constant, String> constant_names;
+ /*! \brief Extern funcs and their return struct infos for each global
variable. */
+ std::unordered_map<const GlobalVarNode*, std::pair<Expr, StructInfo>>
extern_funcs_;
+};
+
+} // namespace relax
+
+namespace transform {
/*!
 * \brief Create the RunCodegen module pass: it offloads functions annotated with kCodegen to
 * the corresponding registered external codegen ("relax.ext.<target>").
 * \param target_options Optional per-target option maps forwarded to each codegen.
 * \param entry_functions Names of the entry functions to rewrite.
 * \return The module pass.
 */
Pass RunCodegen(Optional<Map<String, Map<String, ObjectRef>>> target_options,
                Array<String> entry_functions) {
  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule m,
                                                                            PassContext pc) {
    return relax::CodeGenRunner(m).Run(target_options, entry_functions);
  };
  return CreateModulePass(pass_func, 0, "RunCodegen", {});
}

TVM_REGISTER_GLOBAL("relax.transform.RunCodegen").set_body_typed(RunCodegen);
+
+} // namespace transform
+} // namespace tvm
diff --git a/tests/python/relax/test_codegen_tensorrt.py
b/tests/python/relax/test_codegen_tensorrt.py
new file mode 100644
index 0000000000..164cf3a818
--- /dev/null
+++ b/tests/python/relax/test_codegen_tensorrt.py
@@ -0,0 +1,124 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+import numpy as np
+import tvm
+import tvm.testing
+
+from tvm import relax, relay
+from tvm.script import relax as R
+from tvm.relax.dpl import make_fused_bias_activation_pattern, is_op, wildcard
+
+
def get_relay_residual_block(d_shape, w_shape):
    """Build the Relay reference graph: two conv2d+relu layers plus a residual add."""
    data = relay.var("data", shape=d_shape)
    weight1 = relay.var("weight1", shape=w_shape)
    weight2 = relay.var("weight2", shape=w_shape)

    def conv_relu(inp, weight):
        # Same-padding 3x3 convolution followed by ReLU.
        return relay.nn.relu(relay.nn.conv2d(data=inp, weight=weight, padding=(1, 1)))

    return conv_relu(conv_relu(data, weight1), weight2) + data
+
+
# Relax module under test: two conv2d+relu layers followed by a residual add of the input.
@tvm.script.ir_module
class Conv2dResidualBlock:
    @R.function
    def main(
        data: R.Tensor((1, 64, 56, 56), "float32"),
        weight1: R.Tensor((64, 64, 3, 3), "float32"),
        weight2: R.Tensor((64, 64, 3, 3), "float32"),
    ):
        with R.dataflow():
            conv1 = relax.op.nn.relu(relax.op.nn.conv2d(data, weight1, padding=(1, 1)))
            conv2 = relax.op.nn.relu(relax.op.nn.conv2d(conv1, weight2, padding=(1, 1)))
            out = relax.op.add(conv2, data)
            R.output(out)

        return out
+
+
# Probe once at import time whether the TensorRT codegen was compiled into this build.
has_tensorrt = tvm.get_global_func("relax.ext.tensorrt", True)

# Skip marker applied below (via pytestmark) to every test in this file.
tensorrt_enabled = pytest.mark.skipif(
    not has_tensorrt,
    reason="TENSORRT not enabled.",
)

pytestmark = [tensorrt_enabled]
+
+
def test_tensorrt_offload():
    """Offload conv2d/relu/add to TensorRT and compare against a Relay CPU reference."""
    weight1_np = np.random.randn(64, 64, 3, 3).astype("float32")
    weight2_np = np.random.randn(64, 64, 3, 3).astype("float32")

    # Patterns mapping relax ops onto TensorRT composite names.
    conv_pat = make_fused_bias_activation_pattern(
        "relax.nn.conv2d", with_bias=False, activation=None
    )
    relu_pat = is_op("relax.nn.relu")(wildcard())
    add_pat = is_op("relax.add")(wildcard(), wildcard())

    patterns = [
        ("tensorrt.nn.conv2d", conv_pat),
        ("tensorrt.nn.relu", relu_pat),
        ("tensorrt.add", add_pat),
    ]

    params_np = {"weight1": weight1_np, "weight2": weight2_np}

    # Bind params, partition by pattern, merge partitions, then run the BYOC codegen.
    mod = tvm.transform.Sequential(
        [
            relax.transform.BindParams("main", params_np),
            relax.transform.FuseOpsByPattern(patterns),
            relax.transform.MergeCompositeFunctions(),
            relax.transform.RunCodegen(),
        ]
    )(Conv2dResidualBlock)

    target = "cuda"
    dev = tvm.device(target, 0)
    ex = relax.vm.build(mod, target)

    vm = relax.VirtualMachine(ex, dev)
    f = vm["main"]

    data_np = np.random.randn(1, 64, 56, 56).astype("float32")
    out = f(tvm.nd.array(data_np, dev)).numpy()

    # Reference result via the Relay graph executor on CPU.
    relay_mod = tvm.IRModule.from_expr(get_relay_residual_block(data_np.shape, weight1_np.shape))

    ref = (
        relay.create_executor("graph", mod=relay_mod, device=tvm.cpu(0), target="llvm")
        .evaluate()(*[data_np, weight1_np, weight2_np])
        .numpy()
    )

    # Loose tolerance: the TensorRT path may use different kernels/accumulation order.
    tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-3)


if __name__ == "__main__":
    test_tensorrt_offload()
diff --git a/tests/python/relax/test_transform_codegen_pass.py
b/tests/python/relax/test_transform_codegen_pass.py
new file mode 100644
index 0000000000..e50ad8f5f4
--- /dev/null
+++ b/tests/python/relax/test_transform_codegen_pass.py
@@ -0,0 +1,260 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+import os
+import tvm
+import tvm.testing
+from tvm import relax
+import numpy as np
+from tvm.script import relax as R
+from tvm.relax.testing import transform
+import tempfile
+from tvm.relax.transform.tuning_api import Trace
+from tvm.relax.dpl import is_op, wildcard
+
# Probe the build: is the TensorRT codegen registered, and is the runtime compiled in?
env_checker_codegen = tvm.get_global_func("relax.ext.tensorrt", True)
env_checker_runtime = tvm.get_global_func("relax.is_tensorrt_runtime_enabled", True)

has_tensorrt_codegen = pytest.mark.skipif(
    not env_checker_codegen,
    reason="TensorRT codegen not available",
)
# The runtime check needs both the query function AND a truthy answer from it.
has_tensorrt_runtime = pytest.mark.skipif(
    not env_checker_runtime or not env_checker_runtime(),
    reason="TensorRT runtime not available",
)

# Global variable in pytest that applies markers to all tests.
pytestmark = [has_tensorrt_codegen, has_tensorrt_runtime]

# Target gpu
target_str = "nvidia/geforce-rtx-3070"  # "nvidia/nvidia-t4"
target = tvm.target.Target(target_str)
dev = tvm.cuda()
+
+
def check_executable(exec, dev, inputs, expected):
    """Run `main` from the executable on `dev` and compare the result with `expected`."""
    # NOTE: the parameter name `exec` shadows the builtin; kept for caller compatibility.
    vm = relax.VirtualMachine(exec, dev)
    actual = vm["main"](*inputs)
    tvm.testing.assert_allclose(actual.numpy(), expected.numpy(), atol=1e-5, rtol=1e-5)
+
+
def check_roundtrip(exec0, dev, inputs, expected):
    """Export the executable to a shared library, reload it, and verify both the original
    and the reloaded executable produce `expected`.

    The library is written into a temporary directory rather than the CWD so that
    parallel test runs cannot clobber each other and no file is left behind when an
    assertion fails.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        lib_path = os.path.join(tmp_dir, "exec.so")
        exec0.mod.export_library(lib_path)
        exec1 = relax.vm.Executable(tvm.runtime.load_module(lib_path))

        # The reloaded executable must be indistinguishable from the original.
        assert exec0.stats() == exec1.stats()
        assert exec0.as_text() == exec1.as_text()

        check_executable(exec0, dev, inputs, expected)
        check_executable(exec1, dev, inputs, expected)
+
+
def gen_ground_truth(mod, target, dev, inputs):
    """Compile `mod` without BYOC and run `main` on `inputs` to produce reference outputs."""
    # Lower and run tuning.
    # Since there is no default schedule for GPU in MS yet, this is necessary.
    with tempfile.TemporaryDirectory() as work_dir:
        with target, tvm.transform.PassContext(trace=Trace(mod), opt_level=0):
            lowering = tvm.transform.Sequential(
                [
                    relax.transform.LegalizeOps(),
                    relax.transform.MetaScheduleTuneIRMod(
                        params={}, work_dir=work_dir, max_trials_global=8
                    ),
                    relax.transform.MetaScheduleApplyDatabase(work_dir),
                ]
            )
            new_mod = lowering(mod)
    assert relax.analysis.well_formed(new_mod)
    executable = relax.vm.build(new_mod, target, params={})
    return relax.VirtualMachine(executable, dev)["main"](*inputs)
+
+
# Element-wise graph used by the tests below: a mix of `multiply` and `add` ops so that
# subsets of the graph can be offloaded while the remainder is compiled by TVM.
@tvm.script.ir_module
class InputModule:
    @R.function
    def main(
        x: R.Tensor((16, 16), "float32"), y: R.Tensor((16, 16), "float32")
    ) -> R.Tensor((16, 16), "float32"):
        with R.dataflow():
            z1 = R.multiply(x, y)
            z2 = R.add(z1, x)
            z3 = R.add(z1, z2)
            z4 = R.multiply(z3, z2)
            z5 = R.add(z4, z1)
            R.output(z5)
        return z5
+
+
def setup_test():
    """Return (module, device inputs, expected output) shared by the tests below."""
    # Prepare IRModule and its input
    mod = InputModule
    assert isinstance(mod, tvm.IRModule)

    np0 = np.random.rand(16, 16).astype(np.float32)
    np1 = np.random.rand(16, 16).astype(np.float32)
    data0 = tvm.nd.array(np0, dev)
    data1 = tvm.nd.array(np1, dev)
    inputs = [data0, data1]

    # Ground truth should be generated before annotation
    # due to the conflict with MS task extraction
    # TODO(@sunggg): Sort this out
    expected = gen_ground_truth(mod, target, dev, inputs)
    return mod, inputs, expected
+
+
@tvm.testing.requires_gpu
def test_tensorrt_only():
    """Offload the entire model to TensorRT and check result + export/reload roundtrip."""
    mod, inputs, expected = setup_test()

    # Define patterns that we want to offload to byoc
    # This test will offload entire model
    # Thus, define patterns for both `multiply` and `add` ops
    patterns = [
        ("tensorrt.multiply", is_op("relax.multiply")(wildcard(), wildcard())),
        ("tensorrt.add", is_op("relax.add")(wildcard(), wildcard())),
    ]

    new_mod = tvm.transform.Sequential(
        [
            relax.transform.FuseOpsByPattern(patterns),
            relax.transform.MergeCompositeFunctions(),
            relax.transform.RunCodegen(),
        ]
    )(mod)

    ex0 = relax.vm.build(new_mod, target, params={})
    # Sanity check for the correctness and roundtrip
    check_roundtrip(ex0, dev, inputs, expected)
+
+
@tvm.testing.requires_gpu
def test_mix_use_tensorrt_and_tvm():
    """Offload `add` to TensorRT while compiling `multiply` with TVM/MetaSchedule."""
    mod, inputs, expected = setup_test()

    # Define patterns that we want to offload to byoc
    # This test will only offload `add` op to tensorrt
    # and tune `multiply` op with MetaSchedule
    patterns = [
        ("tensorrt.add", is_op("relax.add")(wildcard(), wildcard())),
    ]

    # Run Codegen pass
    with tempfile.TemporaryDirectory() as work_dir:
        with target, tvm.transform.PassContext(trace=Trace(mod), opt_level=0):
            new_mod = tvm.transform.Sequential(
                [
                    relax.transform.FuseOpsByPattern(patterns),
                    relax.transform.MergeCompositeFunctions(),
                    relax.transform.RunCodegen(),
                    relax.transform.LegalizeOps(),
                    relax.transform.MetaScheduleTuneIRMod(
                        params={}, work_dir=work_dir, max_trials_global=8
                    ),
                    relax.transform.MetaScheduleApplyDatabase(work_dir),
                ]
            )(mod)
    assert relax.analysis.well_formed(new_mod)
    # BUG FIX: `transform` is `tvm.relax.testing.transform` (see imports), which is not the
    # pass-context module; use `tvm.transform.PassContext` as everywhere else in this file.
    with tvm.transform.PassContext(opt_level=0):
        ex0 = relax.vm.build(new_mod, target, params={})

    # Sanity check for the correctness and roundtrip
    check_roundtrip(ex0, dev, inputs, expected)
+
+
# Module that calls the same kCodegen-annotated function ("fused_relax_nn_conv2d_tensorrt")
# twice; RunCodegen must replace both call sites with calls to a single extern func.
@tvm.script.ir_module
class Conv2dx2:
    @R.function
    def main(
        data: R.Tensor((16, 32, 32, 16), dtype="float16"),
        weight1: R.Tensor((16, 3, 3, 16), dtype="float16"),
        weight2: R.Tensor((16, 3, 3, 16), dtype="float16"),
    ) -> R.Tensor((16, 32, 32, 16), dtype="float16"):
        with R.dataflow():
            lv: R.Tensor((16, 32, 32, 16), dtype="float16") = fused_relax_nn_conv2d_tensorrt(
                data, weight1
            )
            gv: R.Tensor((16, 32, 32, 16), dtype="float16") = fused_relax_nn_conv2d_tensorrt(
                lv, weight2
            )
            R.output(gv)
        return gv

    @R.function
    def fused_relax_nn_conv2d_tensorrt(
        data: R.Tensor((16, 32, 32, 16), dtype="float16"),
        weight1: R.Tensor((16, 3, 3, 16), dtype="float16"),
    ) -> R.Tensor((16, 32, 32, 16), dtype="float16"):
        # "Codegen" marks this function for offloading to the tensorrt backend.
        R.func_attr({"Codegen": "tensorrt", "global_symbol": "fused_relax_nn_conv2d_tensorrt"})

        # Inner composite function carrying the pattern name the codegen dispatches on.
        @R.function
        def gv(
            data_1: R.Tensor((16, 32, 32, 16), dtype="float16"),
            weight1_1: R.Tensor((16, 3, 3, 16), dtype="float16"),
        ) -> R.Tensor((16, 32, 32, 16), dtype="float16"):
            R.func_attr({"Composite": "tensorrt.conv2d", "Primitive": 1})
            with R.dataflow():
                gv_1: R.Tensor((16, 32, 32, 16), dtype="float16") = R.nn.conv2d(
                    data_1,
                    weight1_1,
                    padding=[1, 1, 1, 1],
                    data_layout="NHWC",
                    kernel_layout="OHWI",
                    out_layout="NHWC",
                )
                R.output(gv_1)
            return gv_1

        gv1: R.Tensor((16, 32, 32, 16), dtype="float16") = gv(data, weight1)
        return gv1
+
+
# Expected form of Conv2dx2["main"] after RunCodegen: both conv2d calls are rewritten into
# R.call_tir calls to the single extern "fused_relax_nn_conv2d_tensorrt".
@tvm.script.ir_module
class Conv2dx2_after:
    @R.function
    def main(
        data: R.Tensor((16, 32, 32, 16), dtype="float16"),
        weight1: R.Tensor((16, 3, 3, 16), dtype="float16"),
        weight2: R.Tensor((16, 3, 3, 16), dtype="float16"),
    ) -> R.Tensor((16, 32, 32, 16), dtype="float16"):
        with R.dataflow():
            lv = R.call_tir(
                "fused_relax_nn_conv2d_tensorrt",
                (data, weight1),
                out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"),
            )
            gv = R.call_tir(
                "fused_relax_nn_conv2d_tensorrt",
                (lv, weight2),
                out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"),
            )
            R.output(gv)
        return gv
+
+
def test_multiple_calls_same_extern():
    """Two call sites of the same offloaded function must share one extern func."""
    after = relax.transform.RunCodegen()(Conv2dx2)
    tvm.ir.assert_structural_equal(after["main"], Conv2dx2_after["main"])


# TODO(@sunggg): test with more complex patterns (e.g., multiple annots, mixed codegens,
# different ops, const binding)

if __name__ == "__main__":
    pytest.main([__file__])