This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 139921e709 [Unity][BYOC][Pass] RunCodegen and TensorRT  (#14078)
139921e709 is described below

commit 139921e709ac421806061a9f11469062810ea688
Author: Sunghyun Park <[email protected]>
AuthorDate: Tue Feb 21 19:14:52 2023 -0800

    [Unity][BYOC][Pass] RunCodegen and TensorRT  (#14078)
    
    This PR introduces the fundamental workflow for BYOC and integrates TensorRT 
as a demonstration.
---
 cmake/modules/contrib/TensorRT.cmake               |   2 +-
 include/tvm/ir/module.h                            |   6 +
 include/tvm/relax/transform.h                      | 110 ++++--
 python/tvm/ir/module.py                            |  35 ++
 python/tvm/relax/transform/transform.py            |  23 ++
 src/ir/module.cc                                   |  12 +
 .../backend/contrib/codegen_json/codegen_json.h    | 419 +++++++++++++++++++++
 src/relax/backend/contrib/tensorrt/codegen.cc      | 267 +++++++++++++
 src/relax/backend/contrib/utils.h                  | 127 +++++++
 src/relax/transform/run_codegen.cc                 | 190 ++++++++++
 tests/python/relax/test_codegen_tensorrt.py        | 124 ++++++
 tests/python/relax/test_transform_codegen_pass.py  | 260 +++++++++++++
 12 files changed, 1553 insertions(+), 22 deletions(-)

diff --git a/cmake/modules/contrib/TensorRT.cmake 
b/cmake/modules/contrib/TensorRT.cmake
index 696108b501..a749b6e80f 100644
--- a/cmake/modules/contrib/TensorRT.cmake
+++ b/cmake/modules/contrib/TensorRT.cmake
@@ -23,7 +23,7 @@ include (FindPackageHandleStandardArgs)
 
 if(USE_TENSORRT_CODEGEN)
     message(STATUS "Build with TensorRT codegen")
-    tvm_file_glob(GLOB COMPILER_TENSORRT_SRCS 
src/relay/backend/contrib/tensorrt/*.cc)
+    tvm_file_glob(GLOB COMPILER_TENSORRT_SRCS 
src/relay/backend/contrib/tensorrt/*.cc src/relax/backend/contrib/tensorrt/*.cc)
     set_source_files_properties(${COMPILER_TENSORRT_SRCS} PROPERTIES 
COMPILE_FLAGS "-Wno-deprecated-declarations")
     tvm_file_glob(GLOB RUNTIME_TENSORRT_SRCS 
src/runtime/contrib/tensorrt/tensorrt_runtime.cc)
     set_source_files_properties(${RUNTIME_TENSORRT_SRCS} PROPERTIES 
COMPILE_FLAGS "-Wno-deprecated-declarations")
diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h
index fdb44b1188..538ff64ca3 100644
--- a/include/tvm/ir/module.h
+++ b/include/tvm/ir/module.h
@@ -115,6 +115,12 @@ class IRModuleNode : public Object {
     return GetAttr<TObjectRef>(attr_key, Optional<TObjectRef>(default_value));
   }
 
+  /*!
+   * \brief Get the metadata attributes.
+   * \returns The additional meta-data attributes
+   */
+  DictAttrs GetAttrs() const { return attrs; }
+
   /*!
    * \brief Check whether the module has an non-zero integer attr.
    *
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 7d9f3d64b0..7d6c93bcde 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -25,6 +25,7 @@
 #define TVM_RELAX_TRANSFORM_H_
 
 #include <tvm/ir/transform.h>
+#include <tvm/relax/dataflow_pattern.h>
 #include <tvm/relax/expr.h>
 
 namespace tvm {
@@ -67,6 +68,13 @@ TVM_DLL Pass CreateDataflowBlockPass(
     const runtime::TypedPackedFunc<DataflowBlock(DataflowBlock, IRModule, 
PassContext)>& pass_func,
     int opt_level, String name, tvm::Array<String> required, bool traceable = 
false);
 
+/*!
+ * \brief Perform lambda lifting to lift functions from nested into global.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass LambdaLift();
+
 /*!
  * \brief Transform all dataflow structure to non-dataflow version.
  *
@@ -105,27 +113,20 @@ TVM_DLL Pass RewriteDataflowReshape();
 TVM_DLL Pass StaticPlanBlockMemory();
 
 /*!
- * \brief Bind params of function of the module to constant tensors.
- *
- * \param func_name The name of the function to bind parameters.
- * \param params The parameters to bind.
+ * \brief Attach global_symbol to Relax functions and TIR Primfuncs for 
codegen.
  *
  * \return The Pass.
  */
-TVM_DLL Pass BindParams(String func_name, Map<String, runtime::NDArray> 
params);
+TVM_DLL Pass AttachGlobalSymbol();
 
 /*!
- * \brief Fold constant expressions.
- *
- * \return The Pass.
- */
-TVM_DLL Pass FoldConstant();
-/*!
- * \brief Attach global_symbol to Relax functions and TIR Primfuncs for 
codegen.
+ * \brief Transform Relax IR to normal form: transform AST to A-normal form, 
and fill the
+ * checked_type_ and shape_ of expressions.
  *
  * \return The Pass.
  */
-TVM_DLL Pass AttachGlobalSymbol();
+TVM_DLL Pass Normalize();
+
 /*!
  * \brief Bind params of function of the module to constant tensors.
  *
@@ -143,14 +144,6 @@ TVM_DLL Pass BindParams(String func_name, Map<String, 
runtime::NDArray> params);
  */
 TVM_DLL Pass FoldConstant();
 
-/*!
- * \brief Transform Relax IR to normal form: transform AST to A-normal form, 
and fill the
- * checked_type_ and shape_ of expressions.
- *
- * \return The Pass.
- */
-TVM_DLL Pass Normalize();
-
 /*!
  * \brief Legalize high-level operator calls in Relax functions to call_tir
  * with corresponding low-level TIR PrimFuncs.
@@ -190,6 +183,81 @@ TVM_DLL Pass LegalizeOps(Optional<Map<String, PackedFunc>> 
cmap);
  */
 TVM_DLL Pass LiftTransformParams();
 
+/*!
+ * \brief Annotate Op Pattern Kind for TIR functions, which is used in FuseOps.
+ * \note It is an auto-detect pass for "unscheduled prim_funcs", the 
op_pattern will be
+ *       "opaque" if we can't detect it. Users can manually annotate the attr 
`op_pattern`
+ *       to prim_func.
+ * \return The Pass.
+ */
+TVM_DLL Pass AnnotateTIROpPattern();
+
+/*!
+ * \brief This pass groups bindings in a dataflow block of Relax functions and 
generates a new
+ * grouped Relax function for each group, according to the fusion algorithm 
described in the pass
+ * implementation. By grouping bindings into new Relax functions, we 
substitute the bindings in the
+ * function being manipulated into function calls to the new grouped function.
+ *
+ * A follow-up pass named "FuseTIR" will generate a TIR PrimFunc for each 
grouped function.
+ * \param fuse_opt_level The level of fuse optimization.
+ *        -1 indicates that the level will be inferred from pass context.
+ * \return The Pass.
+ */
+TVM_DLL Pass FuseOps(int fuse_opt_level = -1);
+
+/*!
+ * \brief Apply pattern matching to each function in the given module, and 
group matched
+ * expressions into a new function. The end result is similar to FuseOps, but 
fusion is driven
+ * completely by the provided patterns.
+ *
+ * \param pattern_names The name of each pattern. It becomes the value of the 
kComposite attribute
+ * of a fused function after successful matching.
+ * \param patterns The patterns to detect. The order of the patterns 
determines the order
+ * of priority in which they are matched. Higher-priority patterns should come 
earlier in the list.
+ * \param annotate_codegen If true, wrap each created composite function with 
another function,
+ * whose body consists only of a call to the composite function, and annotate 
the outer function
+ * with kCodegen and kGlobalSymbol attributes. The kCodegen attribute is set 
as the prefix of the
+ * corresponding pattern name. For example, "dnnl" if the pattern name is 
"dnnl.conv2d_relu".
+ * This must be True if the created composite functions are intended to be 
offloaded to
+ * an external backend without using the MergeCompositeFunctions pass.
+ * \return The Pass.
+ */
+TVM_DLL Pass FuseOpsByPattern(const tvm::Array<runtime::String>& pattern_names,
+                              const tvm::Array<DFPattern>& patterns, bool 
annotate_codegen = false);
+
+/*!
+ * \brief Group one or multiple composite functions created by 
FuseOpsByPattern into a new
+ *  function. The new function will be annotated with kCodegen and 
GlobalSymbol attributes,
+ *  and it is intended to be offloaded to an external backend.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass MergeCompositeFunctions();
+
+/*!
+ * \brief Fuse relax sub-function into a larger TIR function if possible.
+    this pass works together with FuseOps to perform operator fusion.
+
+ * \return The Pass.
+ */
+TVM_DLL Pass FuseTIR();
+
+/*!
+ * \brief Remove unused global relax functions in an IRModule.
+ * \param entry_functions list of entry functions
+ * \return The Pass.
+ */
+TVM_DLL Pass RemoveUnusedFunctions(Array<runtime::String> entry_functions);
+
+/*!
+ * \brief Run codegen.
+ * \param target_options pairs of target name and compilation options
+ * \param entry_functions list of entry functions
+ * \return The Pass.
+ */
+TVM_DLL Pass RunCodegen(Optional<Map<String, Map<String, ObjectRef>>> 
target_options,
+                        Array<runtime::String> entry_functions);
+
 }  // namespace transform
 }  // namespace relax
 }  // namespace tvm
diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py
index 3daffb2640..6a151d5a89 100644
--- a/python/tvm/ir/module.py
+++ b/python/tvm/ir/module.py
@@ -15,13 +15,18 @@
 # specific language governing permissions and limitations
 # under the License.
 """IRModule that holds the functions and type definitions."""
+from __future__ import annotations
+
+from typing import Dict, Union
 import tvm._ffi
 from tvm._ffi.base import string_types
 from tvm.runtime import Scriptable
+from tvm.runtime.object import Object
 
 from . import _ffi_api
 from . import expr as _expr
 from . import type as _ty
+from .attrs import DictAttrs
 from .base import Node
 
 
@@ -286,6 +291,36 @@ class IRModule(Node, Scriptable):
 
         return _ffi_api.Module_WithAttr(self, attr_key, attr_value)
 
+    def without_attr(self, attr_key: str) -> "IRModule":
+        """Copy the IRModule and remove an attribute key and its associated 
value.
+        Parameters
+        ----------
+        attr_key : str
+            The attribute key.
+        Returns
+        -------
+        mod : IRModule
+            A new copy of the IRModule without the attribute
+        """
+
+        return _ffi_api.Module_WithoutAttr(self, attr_key)
+
+    def with_attrs(self, attr_map: Union[DictAttrs, Dict[str, Object]]) -> 
"IRModule":
+        """Copy the IRModule and add the given attribute map to it.
+        Parameters
+        ----------
+        attr_map: Union[DictAttrs, Dict[str, Object]]
+            The attribute map
+        Returns
+        -------
+        mod : IRModule
+            A new copy of the IRModule with the attribute
+        """
+        if isinstance(attr_map, tvm.ir.DictAttrs):
+            attr_map = attr_map._dict()
+
+        return _ffi_api.Module_WithAttrs(self, attr_map)
+
     def astext(self, show_meta_data=True, annotate=None):
         """Get the text format of the expression.
 
diff --git a/python/tvm/relax/transform/transform.py 
b/python/tvm/relax/transform/transform.py
index 590059739c..9fb2458dc0 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -188,6 +188,29 @@ def RemoveUnusedFunctions(entry_functions: 
Optional[List[str]] = None) -> tvm.ir
     return _ffi_api.RemoveUnusedFunctions(entry_functions)  # type: ignore
 
 
+def RunCodegen(
+    target_options: Optional[dict] = None,
+    entry_functions: Optional[List[str]] = None,
+) -> tvm.ir.transform.Pass:
+    """Produce the runtime::Module with an annotated codegen and global symbol.
+
+    Parameters
+    ----------
+    target_options: Optional[dict]
+        Pairs of a target name and compilation options
+    entry_functions: Optional[List[str]]
+        The set of entry functions to start from.
+
+    Returns
+    -------
+    ret : tvm.transform.Pass
+        The registered pass to remove unused functions.
+    """
+    if entry_functions is None:
+        entry_functions = ["main"]
+    return _ffi_api.RunCodegen(target_options, entry_functions)  # type: ignore
+
+
 def FoldConstant() -> tvm.ir.transform.Pass:
     """Fold constant expressions.
 
diff --git a/src/ir/module.cc b/src/ir/module.cc
index 4a09bdaaf7..8f23f19d35 100644
--- a/src/ir/module.cc
+++ b/src/ir/module.cc
@@ -431,11 +431,23 @@ 
TVM_REGISTER_GLOBAL("ir.Module_ImportFromStd").set_body_typed([](IRModule mod, S
   mod->ImportFromStd(path);
 });
 
+TVM_REGISTER_GLOBAL("ir.Module_GetAttrs").set_body_typed([](IRModule mod) -> 
ObjectRef {
+  return mod->GetAttrs();
+});
+
 TVM_REGISTER_GLOBAL("ir.Module_WithAttr")
     .set_body_typed([](IRModule mod, String key, ObjectRef value) -> IRModule {
       return WithAttr(mod, key, value);
     });
 
+TVM_REGISTER_GLOBAL("ir.Module_WithoutAttr")
+    .set_body_typed([](IRModule mod, String key) -> IRModule { return 
WithoutAttr(mod, key); });
+
+TVM_REGISTER_GLOBAL("ir.Module_WithAttrs")
+    .set_body_typed([](IRModule mod, Map<String, ObjectRef> attr_map) -> 
IRModule {
+      return WithAttrs(mod, attr_map);
+    });
+
 TVM_REGISTER_GLOBAL("ir.Module_GetAttr").set_body_typed([](IRModule mod, 
String key) -> ObjectRef {
   return mod->GetAttr<ObjectRef>(key);
 });
diff --git a/src/relax/backend/contrib/codegen_json/codegen_json.h 
b/src/relax/backend/contrib/codegen_json/codegen_json.h
new file mode 100644
index 0000000000..2197998707
--- /dev/null
+++ b/src/relax/backend/contrib/codegen_json/codegen_json.h
@@ -0,0 +1,419 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file relax/backend/contrib/codegen_json/codegen_json.h
+ * \brief Utilities for json codegen and runtime
+ */
+#ifndef TVM_RELAX_BACKEND_CONTRIB_CODEGEN_JSON_CODEGEN_JSON_H_
+#define TVM_RELAX_BACKEND_CONTRIB_CODEGEN_JSON_CODEGEN_JSON_H_
+
+#include <dmlc/any.h>
+#include <dmlc/json.h>
+#include <tvm/node/reflection.h>
+#include <tvm/relax/struct_info.h>
+#include <tvm/tir/op.h>
+
+#include <cstdint>
+#include <limits>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "../../../../runtime/contrib/json/json_node.h"
+#include "../../../../runtime/contrib/json/json_runtime.h"
+#include "../../../transform/utils.h"
+#include "../utils.h"
+
+namespace tvm {
+namespace relax {
+namespace backend {
+namespace contrib {
+
+using namespace tvm::runtime::json;
+
+using ShapeVector = std::vector<std::vector<int64_t>>;
+using TypeVector = std::vector<std::string>;
+using JSONGraphObjectPtr = std::shared_ptr<JSONGraphNode>;
+
+/*!
+ * \brief Helper class to extract all attributes of a certain op and save them
+ * into text format.
+ */
+class OpAttrExtractor : public AttrVisitor {
+ public:
+  explicit OpAttrExtractor(JSONGraphObjectPtr node) : node_(node) {}
+
+  template <typename T = double, typename = 
std::enable_if_t<std::is_floating_point<T>::value>>
+  std::string Fp2String(const T value) {
+    std::ostringstream out;
+    out.precision(std::numeric_limits<T>::max_digits10);
+    out << value;
+    return out.str();
+  }
+
+  void SetNodeAttr(const char* key, const std::vector<std::string>& value) {
+    std::vector<dmlc::any> attr;
+    attr.emplace_back(value);
+    node_->SetAttr(key, attr);
+  }
+
+  void Visit(const char* key, double* value) final { SetNodeAttr(key, 
{Fp2String(*value)}); }
+
+  void Visit(const char* key, int64_t* value) final { SetNodeAttr(key, 
{std::to_string(*value)}); }
+
+  void Visit(const char* key, uint64_t* value) final { SetNodeAttr(key, 
{std::to_string(*value)}); }
+
+  void Visit(const char* key, int* value) final { SetNodeAttr(key, 
{std::to_string(*value)}); }
+
+  void Visit(const char* key, bool* value) final { SetNodeAttr(key, 
{std::to_string(*value)}); }
+
+  void Visit(const char* key, std::string* value) final { SetNodeAttr(key, 
{*value}); }
+
+  void Visit(const char* key, DataType* value) final {
+    if (!value->is_void()) {
+      SetNodeAttr(key, {runtime::DLDataType2String(*value)});
+    } else {
+      SetNodeAttr(key, {""});
+    }
+  }
+
+  void Visit(const char* key, runtime::ObjectRef* value) final {
+    if (const auto* an = (*value).as<ArrayNode>()) {
+      std::vector<std::string> attr;
+      for (size_t i = 0; i < an->size(); ++i) {
+        if (const auto* im = (*an)[i].as<IntImmNode>()) {
+          attr.push_back(std::to_string(im->value));
+        } else if (const auto* fm = (*an)[i].as<FloatImmNode>()) {
+          attr.push_back(Fp2String(fm->value));
+        } else if (const auto* str = (*an)[i].as<StringObj>()) {
+          String s = GetRef<String>(str);
+          attr.push_back(s);
+        } else {
+          LOG(FATAL) << "Not supported type: " << (*an)[i]->GetTypeKey();
+        }
+      }
+      SetNodeAttr(key, attr);
+    } else if (!(*value).defined()) {  // Skip NullValue
+      SetNodeAttr(key, std::vector<std::string>{""});
+    } else if (const auto* im = (*value).as<IntImmNode>()) {
+      SetNodeAttr(key, std::vector<std::string>{std::to_string(im->value)});
+    } else if (const auto* fm = (*value).as<FloatImmNode>()) {
+      SetNodeAttr(key, std::vector<std::string>{Fp2String(fm->value)});
+    } else if (const auto* str = (*value).as<StringObj>()) {
+      String s = GetRef<String>(str);
+      SetNodeAttr(key, std::vector<std::string>{s});
+    } else {
+      LOG(FATAL) << "Not yet supported type: " << (*value)->GetTypeKey() << ": 
" << *value;
+    }
+  }
+
+  void Visit(const char* key, runtime::NDArray* value) final {
+    LOG(FATAL) << "NDArray is not allowed in op attribute";
+  }
+
+  void Visit(const char* key, void** value) final {
+    LOG(FATAL) << "void pointer is not allowed in op attribute";
+  }
+
+  void Extract(Object* node) {
+    if (node) {
+      reflection_->VisitAttrs(node, this);
+    }
+  }
+
+ private:
+  JSONGraphObjectPtr node_;
+  ReflectionVTable* reflection_ = ReflectionVTable::Global();
+};
+
+using NodeEntries = std::vector<JSONGraphNodeEntry>;
+
+/*! \brief Serialize a Relax expression to JSON. */
+class JSONSerializer : public relax::MemoizedExprTranslator<NodeEntries> {
+ public:
+  using MemoizedExprTranslator<NodeEntries>::VisitExpr_;
+  using MemoizedExprTranslator<NodeEntries>::VisitBinding_;
+
+  /*!
+   * \brief Constructor
+   * \param constant_names The names of all constants in the original module.
+   */
+  explicit JSONSerializer(const Map<Constant, String>& constant_names)
+      : constant_names_(constant_names) {}
+
+  void serialize(Function func) {
+    // First we convert all the parameters into input nodes.
+    for (const auto& param : func->params) {
+      auto node_ptr = std::make_shared<JSONGraphNode>(param->name_hint(), 
"input" /* op_type_ */);
+      memo_[param] = AddNode(node_ptr, param);
+    }
+    heads_ = VisitExpr(func->body);
+  }
+
+  /*!\brief Return the required constants. */
+  Array<String> GetConstantNames() const { return constants_used_; }
+
+  /*!\brief Return the generated json. */
+  std::string GetJSON() {
+    std::ostringstream os;
+    dmlc::JSONWriter writer(&os);
+    Save(&writer);
+    return os.str();
+  }
+
+ protected:
+  /*!
+   * \brief Add a node to graph.
+   *
+   * \param node A graph node. It is a shared pointer. Some attributes of it
+   *        will be added, i.e. shape and type. These attributes are attached 
to
+   *        the JSON graph in the end.
+   * \param expr The relax expression.
+   * \return A list of graph entry nodes. If the relax expr is a tuple type, we
+   *         will flatten it.
+   */
+  NodeEntries AddNode(JSONGraphObjectPtr node, const Expr& expr) {
+    auto struct_info = GetStructInfo(expr);
+    auto node_id = nodes_.size();
+    nodes_.push_back(node);
+    NodeEntries ret;
+    ShapeVector shape;
+    TypeVector dtype;
+
+    // Flatten tuple node.
+    if (const auto* tuple_sinfo = struct_info.as<TupleStructInfoNode>()) {
+      for (size_t i = 0; i < tuple_sinfo->fields.size(); ++i) {
+        const auto* tensor_sinfo = 
tuple_sinfo->fields[i].as<TensorStructInfoNode>();
+        ICHECK(tensor_sinfo) << "Expect TensorStructInfo, but received: ."
+                             << tuple_sinfo->fields[i]->GetTypeKey();
+        ICHECK(tensor_sinfo->shape.defined()) << "Expect shape to be defined.";
+        ShapeExpr output_shape = 
Downcast<ShapeExpr>(tensor_sinfo->shape.value());
+        ret.push_back(JSONGraphNodeEntry(node_id, i));
+        shape.emplace_back(GetIntShape(output_shape->values));
+        dtype.emplace_back(DType2String(tensor_sinfo->dtype));
+      }
+      node->SetNumOutput(tuple_sinfo->fields.size());
+    } else {
+      const auto* tensor_sinfo = struct_info.as<TensorStructInfoNode>();
+      ICHECK(tensor_sinfo) << "Expect TensorStructInfo, but received: "
+                           << struct_info->GetTypeKey();
+      ICHECK(tensor_sinfo->shape.defined()) << "Expect shape to be defined.";
+      ShapeExpr output_shape = 
Downcast<ShapeExpr>(tensor_sinfo->shape.value());
+
+      shape.emplace_back(GetIntShape(output_shape->values));
+      dtype.emplace_back(DType2String(tensor_sinfo->dtype));
+      ret.push_back(JSONGraphNodeEntry(node_id, 0));
+    }
+    std::vector<dmlc::any> shape_attrs;
+    shape_attrs.emplace_back(shape);
+    node->SetAttr("shape", shape_attrs);
+
+    std::vector<dmlc::any> type_attrs;
+    type_attrs.emplace_back(dtype);
+    node->SetAttr("dtype", type_attrs);
+    return ret;
+  }
+
+  void SetCallNodeAttribute(JSONGraphObjectPtr node, const CallNode* cn) {
+    if (cn->op.as<OpNode>()) {
+      OpAttrExtractor extractor(node);
+      const Object* call_attr = cn->attrs.get();
+      extractor.Extract(const_cast<Object*>(call_attr));
+    } else if (const auto* fn = cn->op.as<FunctionNode>()) {
+      ICHECK(false);
+      auto pattern = fn->GetAttr<String>(attr::kPartitionedFromPattern);
+      ICHECK(pattern.defined());
+      std::vector<std::string> values;
+      values.push_back(pattern.value());
+      std::vector<dmlc::any> attr;
+      attr.emplace_back(values);
+      node->SetAttr("PartitionedFromPattern", attr);
+    }
+  }
+
+  NodeEntries VisitBinding_(const MatchCastNode* binding) {
+    LOG(FATAL) << "JSON runtime currently doesn't match cast\n";
+    return {};
+  }
+
+  NodeEntries VisitBinding(const Binding& binding) {
+    NodeEntries nodes;
+    if (const auto* node = binding.as<VarBindingNode>()) {
+      auto from_b = VisitBinding_(node);
+      nodes.insert(nodes.end(), from_b.begin(), from_b.end());
+    } else if (const auto* node = binding.as<MatchCastNode>()) {
+      auto from_b = VisitBinding_(node);
+      nodes.insert(nodes.end(), from_b.begin(), from_b.end());
+    } else {
+      LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey();
+    }
+    return nodes;
+  }
+
+  NodeEntries VisitBindingBlock(const BindingBlock& block) {
+    NodeEntries nodes;
+    if (const auto* node = block.as<DataflowBlockNode>()) {
+      auto from_bb = VisitBindingBlock_(node);
+      nodes.insert(nodes.end(), from_bb.begin(), from_bb.end());
+    } else if (const auto* node = block.as<BindingBlockNode>()) {
+      auto from_bb = VisitBindingBlock_(node);
+      nodes.insert(nodes.end(), from_bb.begin(), from_bb.end());
+    } else {
+      LOG(FATAL) << "TypeError: Invalid type: " << block->GetTypeKey();
+    }
+    return nodes;
+  }
+
+  NodeEntries VisitBindingBlock_(const BindingBlockNode* block) {
+    NodeEntries nodes;
+    for (Binding binding : block->bindings) {
+      auto from_b = VisitBinding(binding);
+      nodes.insert(nodes.end(), from_b.begin(), from_b.end());
+    }
+    return nodes;
+  }
+
+  NodeEntries VisitBindingBlock_(const DataflowBlockNode* block) {
+    NodeEntries nodes;
+    for (Binding binding : block->bindings) {
+      auto from_b = VisitBinding(binding);
+      nodes.insert(nodes.end(), from_b.begin(), from_b.end());
+    }
+    return nodes;
+  }
+
+  NodeEntries VisitExpr_(const SeqExprNode* op) {
+    NodeEntries nodes;
+    for (BindingBlock block : op->blocks) {
+      VisitBindingBlock(block);
+    }
+    auto from_body = VisitExpr(op->body);
+    nodes.insert(nodes.end(), from_body.begin(), from_body.end());
+    return nodes;
+  }
+
+  NodeEntries VisitExprDefault_(const Object* op) {
+    LOG(FATAL) << "JSON runtime currently doesn't support " << 
op->GetTypeKey();
+    return {};
+  }
+
+  NodeEntries VisitExpr_(const ConstantNode* cn) {
+    auto name = constant_names_.find(GetRef<Constant>(cn));
+    ICHECK(name != constant_names_.end())
+        << "Cannot find the name of the constant: " << GetRef<Constant>(cn);
+    constants_used_.push_back((*name).second);
+    auto node = std::make_shared<JSONGraphNode>((*name).second, "const" /* 
op_type_ */);
+    return AddNode(node, GetRef<Expr>(cn));
+  }
+
+  NodeEntries VisitExpr_(const TupleNode* tn) {
+    NodeEntries fields;
+    for (const auto& field : tn->fields) {
+      auto ref = VisitExpr(field);
+      fields.insert(fields.end(), ref.begin(), ref.end());
+    }
+    return fields;
+  }
+
+  NodeEntries VisitExpr_(const CallNode* cn) {
+    Expr expr = GetRef<Expr>(cn);
+    std::string name;
+    if (const auto* op_node = cn->op.as<OpNode>()) {
+      name = op_node->name;
+    } else if (const auto* fn = cn->op.as<FunctionNode>()) {
+      auto comp = fn->GetAttr<String>(attr::kComposite);
+      ICHECK(comp.defined()) << "JSON runtime only supports composite 
functions.";
+      name = comp.value();
+    } else {
+      LOG(FATAL) << "JSON runtime does not support calls to " << 
cn->op->GetTypeKey();
+    }
+
+    // TODO(@sunggg): Revisit when we have op naming convention.
+    // Currently, simply remove "relax." prefix to make it work.
+    name = std::string("tensorrt.") + name.substr(6);
+
+    NodeEntries inputs;
+    for (const auto& arg : cn->args) {
+      auto res = VisitExpr(arg);
+      inputs.insert(inputs.end(), res.begin(), res.end());
+    }
+    auto node = std::make_shared<JSONGraphNode>(name,     /* name_ */
+                                                "kernel", /* op_type_ */
+                                                inputs, 1 /* num_outputs_ */);
+    SetCallNodeAttribute(node, cn);
+    return AddNode(node, GetRef<Expr>(cn));
+  }
+
+  NodeEntries VisitExpr_(const TupleGetItemNode* gtn) {
+    auto vtuple = VisitExpr(gtn->tuple);
+    return {vtuple[gtn->index]};
+  }
+
+  NodeEntries VisitExpr_(const FunctionNode* fn) {
+    ICHECK(fn->GetAttr<String>(attr::kComposite).defined())
+        << "JSON runtime only supports composite functions";
+
+    // FunctionNode should be handled by the caller.
+    return {};
+  }
+
+  /*!
+   * \brief Save to JSON graph
+   *
+   * \param writer A json writer
+   */
+  void Save(dmlc::JSONWriter* writer) {
+    std::vector<size_t> arg_nodes;
+    for (size_t i = 0; i < nodes_.size(); ++i) {
+      auto node = nodes_[i];
+      if (node->IsLeaf()) {
+        arg_nodes.push_back(i);
+      }
+    }
+    size_t num_entry = 0;
+    std::vector<size_t> node_row_ptr{0};
+    for (auto node : nodes_) {
+      num_entry += node->GetNumOutput();
+      node_row_ptr.push_back(num_entry);
+    }
+    writer->BeginObject();
+    writer->WriteObjectKeyValue("nodes", nodes_);
+    writer->WriteObjectKeyValue("arg_nodes", arg_nodes);
+    writer->WriteObjectKeyValue("heads", heads_);
+    writer->WriteObjectKeyValue("node_row_ptr", node_row_ptr);
+    writer->EndObject();
+  }
+
+ private:
+  /*! \brief JSON graph nodes. */
+  std::vector<JSONGraphObjectPtr> nodes_;
+  /*! \brief Output of the JSON graph. */
+  NodeEntries heads_;
+  /*! \brief The list of required constants, ordered. */
+  Array<String> constants_used_;
+  /*! \brief The names of all constants in the original module. */
+  const Map<Constant, String>& constant_names_;
+};
+
+}  // namespace contrib
+}  // namespace backend
+}  // namespace relax
+}  // namespace tvm
+#endif  // TVM_RELAX_BACKEND_CONTRIB_CODEGEN_JSON_CODEGEN_JSON_H_
diff --git a/src/relax/backend/contrib/tensorrt/codegen.cc 
b/src/relax/backend/contrib/tensorrt/codegen.cc
new file mode 100644
index 0000000000..5ce6bf5e7d
--- /dev/null
+++ b/src/relax/backend/contrib/tensorrt/codegen.cc
@@ -0,0 +1,267 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/backend/contrib/tensorrt/codegen.cc
+ * \brief Implementation of the TensorRT JSON serializer.
+ */
+#include <tvm/ir/module.h>
+// TODO(sunggg): add operator attribute when it's ready
+// #include <tvm/relax/attrs/nn.h>
+#include <tvm/relax/type.h>
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "../../../transform/utils.h"
+#include "../codegen_json/codegen_json.h"
+#include "../utils.h"
+
+#if TVM_GRAPH_EXECUTOR_TENSORRT
+#include "NvInfer.h"
+#endif
+
+namespace tvm {
+namespace relax {
+namespace contrib {
+
/*! \brief Attributes to store the compiler options for TensorRT. */
struct TensorRTCompilerConfigNode : public tvm::AttrsNode<TensorRTCompilerConfigNode> {
  // TensorRT version the engine is targeted at, as (major, minor, patch).
  Array<Integer> tensorrt_version;
  bool use_implicit_batch;
  size_t max_workspace_size;
  bool remove_no_mac_subgraphs;
  bool use_fp16;
  bool use_uint8;

  TVM_DECLARE_ATTRS(TensorRTCompilerConfigNode, "relax.ext.attrs.TensorRTCompilerConfigNode") {
    TVM_ATTR_FIELD(tensorrt_version)
        .describe("TensorRT version as (major, minor, patch).")
        .set_default(Array<Integer>({6, 0, 1}));
    TVM_ATTR_FIELD(use_implicit_batch).set_default(true);
    // Default workspace budget: 1 GiB.
    TVM_ATTR_FIELD(max_workspace_size).set_default(size_t(1) << 30);
    TVM_ATTR_FIELD(remove_no_mac_subgraphs).set_default(false);
    TVM_ATTR_FIELD(use_fp16).set_default(false);
    TVM_ATTR_FIELD(use_uint8).set_default(false);
  }
};
+
/*! \brief Managed reference to TensorRTCompilerConfigNode. */
class TensorRTCompilerConfig : public Attrs {
 public:
  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TensorRTCompilerConfig, Attrs,
                                            TensorRTCompilerConfigNode);
};
+
TVM_REGISTER_NODE_TYPE(TensorRTCompilerConfigNode);
// Makes the config retrievable via
// PassContext::GetConfig<TensorRTCompilerConfig>("relax.ext.tensorrt.options").
TVM_REGISTER_PASS_CONFIG_OPTION("relax.ext.tensorrt.options", TensorRTCompilerConfig);

// Short-hands for the JSON codegen infrastructure shared with other BYOC backends.
using JSONGraphNode = tvm::runtime::json::JSONGraphNode;
using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry;
using JSONGraphObjectPtr = backend::contrib::JSONGraphObjectPtr;
using OpAttrExtractor = backend::contrib::OpAttrExtractor;
using JSONSerializer = backend::contrib::JSONSerializer;
+
class TensorRTJSONSerializer;

/*!
 * \brief Collect the constants and attributes from all operator calls in the body
 * of a "Composite" function.
 */
class CollectFromCompositeFunctionBody : public ExprVisitor {
 public:
  explicit CollectFromCompositeFunctionBody(TensorRTJSONSerializer* serializer)
      : serializer_(serializer), node_(std::make_shared<JSONGraphNode>()) {}

  // Implementations appear after TensorRTJSONSerializer is fully declared.
  void VisitExpr_(const ConstantNode* constant_node) final;
  void VisitExpr_(const CallNode* call_node) final;

  /*! \brief Copy the operator attributes of \p call_node onto the temporary node. */
  void SetGenericAttributes(const CallNode* call_node) {
    OpAttrExtractor extractor(node_);
    const Object* attr_obj = call_node->attrs.get();
    // NOTE(review): OpAttrExtractor takes a non-const pointer; presumably it
    // only reads the attrs — confirm before relying on const-ness here.
    extractor.Extract(const_cast<Object*>(attr_obj));
  }

  /*! \brief The serializer used to translate constants into graph entries. */
  TensorRTJSONSerializer* serializer_;
  /*! \brief Accumulated translated arguments. */
  std::vector<JSONGraphNodeEntry> args_;
  /*!
   * \brief Temporary node into which we'll accumulate attributes. Ideally this would be the
   * final JSONGraphNode however we don't yet know how many inputs that will have.
   */
  JSONGraphObjectPtr node_;
};
+
/*!
 * \brief Generates a TensorRTModule from a relax expression by serializing the expression to a
 * json representation. TensorRT is not required here because use of TensorRT APIs is deferred until
 * runtime.
 */
class TensorRTJSONSerializer : public JSONSerializer {
 public:
  explicit TensorRTJSONSerializer(Map<Constant, String> constant_names, Map<Var, Expr> bindings)
      : JSONSerializer(constant_names), bindings_(bindings) {}

  using JSONSerializer::VisitExpr_;

  std::vector<JSONGraphNodeEntry> VisitExpr_(const CallNode* call_node) final {
    // The call must be to an inline "Composite" function
    const auto* fn_var = call_node->op.as<VarNode>();
    ICHECK(fn_var);
    const auto fn = Downcast<Function>(bindings_[GetRef<Var>(fn_var)]);

    // The composite name (e.g. "tensorrt.nn.conv2d") becomes the JSON op name.
    auto opt_composite = fn->GetAttr<String>(attr::kComposite);
    ICHECK(opt_composite.defined());
    std::string name = opt_composite.value();

    // Collect the constants and attributes of all operator calls inside the composite body.
    CollectFromCompositeFunctionBody collector(this);
    collector.VisitExpr(fn->body);

    // Capture the args to the "Composite" function as inputs for this node.
    std::vector<JSONGraphNodeEntry> inputs;
    for (const auto& arg : call_node->args) {
      auto res = VisitExpr(arg);
      inputs.insert(inputs.end(), res.begin(), res.end());
    }

    // Capture constants from the composite function body as additional inputs for this node.
    for (const auto& node : collector.args_) {
      inputs.emplace_back(node);
    }

    // Create the final node.
    auto node = std::make_shared<JSONGraphNode>(name,
                                                /*op_type=*/"kernel", inputs,
                                                /*num_output=*/1);

    // Transfer attributes from the collector's node to the final node.
    node->CaptureAttrs(*collector.node_);

    // Capture global settings on the JSON node.
    SaveGlobalAttributes(node);

    VLOG(1) << name << " has " << node->GetInputs().size() << " inputs";

    return AddNode(node, GetRef<Expr>(call_node));
  }

  /*!
   * \brief Attach the TensorRT compiler options (from the current PassContext,
   * or defaults if unset) to \p node as JSON attributes.
   */
  static void SaveGlobalAttributes(std::shared_ptr<JSONGraphNode> node) {
    auto ctx = transform::PassContext::Current();
    auto cfg = ctx->GetConfig<TensorRTCompilerConfig>("relax.ext.tensorrt.options");
    if (!cfg.defined()) {
      cfg = AttrsWithDefaultValues<TensorRTCompilerConfig>();
    }
    ICHECK_EQ(cfg.value()->tensorrt_version.size(), 3);
    // JSON node attributes are stored as vectors of strings wrapped in dmlc::any.
    std::vector<std::string> tensorrt_version = {
        std::to_string(cfg.value()->tensorrt_version[0].IntValue()),
        std::to_string(cfg.value()->tensorrt_version[1].IntValue()),
        std::to_string(cfg.value()->tensorrt_version[2].IntValue())};
    std::vector<std::string> use_implicit_batch = {std::to_string(cfg.value()->use_implicit_batch)};
    std::vector<std::string> max_workspace_size = {std::to_string(cfg.value()->max_workspace_size)};
    std::vector<std::string> use_fp16 = {std::to_string(cfg.value()->use_fp16)};
    std::vector<std::string> use_uint8 = {std::to_string(cfg.value()->use_uint8)};
    std::vector<dmlc::any> tensorrt_version_attr, use_implicit_batch_attr, max_workspace_size_attr,
        use_fp16_attr, use_uint8_attr;
    tensorrt_version_attr.emplace_back(tensorrt_version);
    use_implicit_batch_attr.emplace_back(use_implicit_batch);
    max_workspace_size_attr.emplace_back(max_workspace_size);
    use_fp16_attr.emplace_back(use_fp16);
    use_uint8_attr.emplace_back(use_uint8);
    node->SetAttr("tensorrt_version", tensorrt_version_attr);
    node->SetAttr("use_implicit_batch", use_implicit_batch_attr);
    node->SetAttr("max_workspace_size", max_workspace_size_attr);
    node->SetAttr("use_fp16", use_fp16_attr);
    node->SetAttr("use_uint8", use_uint8_attr);
  }

 private:
  /*! \brief The bindings to look up composite functions. */
  Map<Var, Expr> bindings_;
};
+
+void CollectFromCompositeFunctionBody::VisitExpr_(const ConstantNode* 
constant_node) {
+  for (const auto& entry : 
serializer_->VisitExpr(GetRef<Constant>(constant_node))) {
+    args_.emplace_back(entry);
+  }
+}
+
void CollectFromCompositeFunctionBody::VisitExpr_(const CallNode* call_node) {
  // Record the operator's attributes on the temporary node, then recurse into
  // the call's arguments (which may contain constants to collect).
  SetGenericAttributes(call_node);
  ExprVisitor::VisitExpr_(call_node);
}
+
+/*!
+ * \brief Create runtime modules for TensorRT.
+ * \param functions The extern functions to be compiled via TensorRT
+ * \return Runtime modules.
+ */
+Array<runtime::Module> TensorRTCompiler(Array<Function> functions,
+                                        Map<String, ObjectRef> /*unused*/,
+                                        Map<Constant, String> constant_names) {
+  Array<runtime::Module> compiled_functions;
+  for (const auto& func : functions) {
+    VLOG(1) << "TensorRT partition:" << std::endl << func;
+    TensorRTJSONSerializer serializer(constant_names, AnalyzeVar2Value(func));
+    serializer.serialize(func);
+    std::string graph_json = serializer.GetJSON();
+    VLOG(1) << "TensorRT JSON:" << std::endl << graph_json;
+    auto constant_names = serializer.GetConstantNames();
+    const auto* pf = runtime::Registry::Get("runtime.tensorrt_runtime_create");
+    ICHECK(pf != nullptr) << "Cannot find TensorRT runtime module create 
function.";
+    std::string func_name = GetExtSymbol(func);
+    VLOG(1) << "Creating tensorrt runtime::Module for '" << func_name << "'";
+    compiled_functions.push_back((*pf)(func_name, graph_json, constant_names));
+  }
+  return compiled_functions;
+}
+
+TVM_REGISTER_GLOBAL("relax.ext.tensorrt").set_body_typed(TensorRTCompiler);
+
/*!
 * \brief Check whether TensorRT graph executor is enabled.
 * \return True if enabled, False if not.
 */
inline constexpr bool IsTensorRTRuntimeEnabled() {
// Resolved at compile time from the build flag.
#if TVM_GRAPH_EXECUTOR_TENSORRT
  return true;
#else
  return false;
#endif  // TVM_GRAPH_EXECUTOR_TENSORRT
}
+
/*!
 * \brief Get TensorRT version that TVM is built against.
 * \return Array of three integers for major, minor, and patch, or empty array if TensorRT graph
 * runtime is not enabled.
 */
Array<Integer> GetTensorRTVersion() {
#if TVM_GRAPH_EXECUTOR_TENSORRT
  // NV_TENSORRT_* macros come from NvInfer.h, included only when building with TensorRT.
  return {Integer(NV_TENSORRT_MAJOR), Integer(NV_TENSORRT_MINOR), Integer(NV_TENSORRT_PATCH)};
#else
  return {};
#endif  // TVM_GRAPH_EXECUTOR_TENSORRT
}
+
+TVM_REGISTER_GLOBAL("relax.is_tensorrt_runtime_enabled").set_body_typed(IsTensorRTRuntimeEnabled);
+TVM_REGISTER_GLOBAL("relax.get_tensorrt_version").set_body_typed(GetTensorRTVersion);
+
+}  // namespace contrib
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/backend/contrib/utils.h 
b/src/relax/backend/contrib/utils.h
new file mode 100644
index 0000000000..4190ad66b6
--- /dev/null
+++ b/src/relax/backend/contrib/utils.h
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file relax/backend/contrib/utils.h
+ * \brief Utils function for backend
+ */
+#ifndef TVM_RELAX_BACKEND_CONTRIB_UTILS_H_
+#define TVM_RELAX_BACKEND_CONTRIB_UTILS_H_
+
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr.h>
+
+#include <string>
+#include <vector>
+
+#include "../../transform/utils.h"
+
+namespace tvm {
+namespace relax {
+namespace backend {
+
/*!
 * \brief Look up a function in the TVM global packed-function registry.
 *
 * \param func_name The registered name of the function.
 * \return Pointer to the PackedFunc, or nullptr if it is not registered.
 */
inline const PackedFunc* GetPackedFunc(const std::string& func_name) {
  return tvm::runtime::Registry::Get(func_name);
}
+
+/*!
+ * \brief Extract shape from an IndexExpr array to std::vector<int64_t>
+ *
+ * \param shape The shape in Array
+ * \return The converted shape in std::vector<int64_t>
+ */
+
+inline std::vector<int64_t> GetIntShape(const Array<PrimExpr>& shape) {
+  std::vector<int64_t> ret;
+  for (const auto& dim : shape) {
+    const int64_t* pval = tir::as_const_int(dim);
+    ret.push_back(pval ? *pval : -1);
+  }
+  return ret;
+}
+
/*!
 * \brief Convert a data type to its string representation (e.g. "float32")
 *
 * \param dtype The data type to convert
 * \return std::string string format of type
 */
inline std::string DType2String(const tvm::DataType dtype) {
  std::ostringstream os;
  if (dtype.is_float()) {
    os << "float";
  } else if (dtype.is_int()) {
    os << "int";
  } else if (dtype.is_uint()) {
    os << "uint";
  } else if (dtype.is_bfloat16()) {
    os << "bfloat";
  } else if ((*GetPackedFunc("runtime._datatype_get_type_registered"))(dtype.code())) {
    // Custom datatypes registered at runtime are printed as custom[<name>]<bits>.
    os << "custom["
       << (*GetPackedFunc("runtime._datatype_get_type_name"))(dtype.code()).operator std::string()
       << "]";
  } else {
    LOG(FATAL) << "Unknown type with code " << static_cast<unsigned>(dtype.code());
  }
  // The bit width is appended in every case, e.g. "int" + "8" -> "int8".
  os << dtype.bits();
  return os.str();
}
+
+/*!
+ * \brief Check if a call node is calling an op with the given name
+ * \param call The call node whose callee we want to check
+ * \param op_name The name of the op
+ * \return true if the callee op matches with the op name
+ */
+inline bool IsOp(const CallNode* call, const std::string& op_name) {
+  const auto* op_node = call->op.as<OpNode>();
+  if (!op_node) return false;
+  Op op = GetRef<Op>(op_node);
+  return op == Op::Get(op_name);
+}
+
+/*!
+ * \brief Return a call node within the function which calls an op with the 
given name
+ * The function must contain exactly one call to such op.
+ * \param f The function to look for an op.
+ * \param op_name The name of the op
+ * \return A call node which calls an op with the given name
+ */
+inline const CallNode* GetOpInFunction(Function f, const std::string& op_name) 
{
+  auto local_bindings = AnalyzeVar2Value(f);
+  for (const auto& entry : local_bindings) {
+    if (auto call = entry.second.as<CallNode>(); call && backend::IsOp(call, 
op_name)) {
+      return call;
+    }
+  }
+  LOG(FATAL) << op_name << " not found in the function:\n" << f;
+  return nullptr;
+}
+
+}  // namespace backend
+}  // namespace relax
+}  // namespace tvm
+
+#endif  // TVM_RELAX_BACKEND_CONTRIB_UTILS_H_
diff --git a/src/relax/transform/run_codegen.cc 
b/src/relax/transform/run_codegen.cc
new file mode 100644
index 0000000000..114b7d2a34
--- /dev/null
+++ b/src/relax/transform/run_codegen.cc
@@ -0,0 +1,190 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/transform/run_codegen.cc
+ * \brief Run codegen for annotated relax functions.
+ */
+
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr_functor.h>
+
+#include <iostream>
+
+#include "utils.h"
+
+namespace tvm {
+namespace relax {
+
+class CodeGenRunner : ExprMutator {
+ public:
+  using OptionMap = Map<String, ObjectRef>;
+
+  explicit CodeGenRunner(IRModule mod) : ExprMutator(mod) {}
+
+  IRModule Run(Optional<Map<String, OptionMap>> target_options, Array<String> 
entry_functions) {
+    IRModule mod = builder_->GetContextIRModule();
+    for (const String& entry_func_name : entry_functions) {
+      auto entry_func = mod->Lookup(entry_func_name);
+      auto gvar = mod->GetGlobalVar(entry_func_name);
+      builder_->UpdateFunction(gvar, 
Downcast<BaseFunc>(VisitExpr(entry_func)));
+    }
+
+    auto ext_mods = InvokeCodegen(mod, target_options.value_or({}));
+    auto out_mod = builder_->GetContextIRModule();
+
+    if (ext_mods.size()) {
+      out_mod = WithAttr(out_mod, tvm::attr::kExternalMods, 
std::move(ext_mods));
+    }
+
+    if (constant_names.size()) {
+      // Some backends (e.g. TensorRT) expect constants to be passed when they 
are instantiated
+      Map<String, runtime::NDArray> constants;
+      for (const auto& [constant, name] : constant_names) {
+        ICHECK(!constants.count(name)) << "More than one constant with the 
name " << name;
+        constants.Set(name, constant->data);
+      }
+      out_mod = WithAttr(out_mod, tvm::attr::kConstNameToConstant, 
std::move(constants));
+    }
+
+    // TODO(@tvm-team): Implicit pass dependency. Revisit when we have a 
better way to handle this.
+    return RemoveUnusedFunctions(out_mod, entry_functions);
+  }
+
+  using ExprMutator::VisitExpr_;
+
+  Expr VisitExpr_(const CallNode* call_node) override {
+    auto call = Downcast<Call>(ExprMutator::VisitExpr_(call_node));
+    if (auto const* gvar_node = call_node->op.as<GlobalVarNode>()) {
+      const GlobalVar gvar = GetRef<GlobalVar>(gvar_node);
+
+      auto create_call_tir = [call_node, this](Expr extern_func, StructInfo 
ret_struct_info) {
+        Array<Expr> new_args({extern_func});
+        new_args.push_back(Tuple(call_node->args.Map([this](Expr arg) { return 
VisitExpr(arg); })));
+
+        static const Op& call_op = Op::Get("relax.call_tir");
+
+        return Call(call_op, new_args, tvm::Attrs(), {ret_struct_info});
+      };
+
+      if (auto it = extern_funcs_.find(gvar_node); it != extern_funcs_.end()) {
+        return create_call_tir(it->second.first, it->second.second);
+      } else {
+        // TODO(@sunggg): Is there any better way to get this func?
+        Function func = 
Downcast<Function>(builder_->GetContextIRModule()->Lookup(gvar));
+        Expr new_func = VisitExpr(func);
+
+        if (new_func->IsInstance<ExternFuncNode>()) {
+          extern_funcs_[gvar_node] = {new_func, func->ret_struct_info};
+          // Remove the global symbol and codegen attributes from the function 
so that it can be
+          // removed the module.
+          static const runtime::PackedFunc* RemoveFuncAttrFunc =
+              runtime::Registry::Get("ir.BaseFuncWithoutAttr");
+          ICHECK(RemoveFuncAttrFunc);
+          func = (*RemoveFuncAttrFunc)(func, tvm::attr::kGlobalSymbol);
+          func = (*RemoveFuncAttrFunc)(func, attr::kCodegen);
+          builder_->UpdateFunction(gvar, func);
+          return create_call_tir(new_func, func->ret_struct_info);
+        }
+      }
+    }
+    Array<Expr> new_args;
+    for (const auto& arg : call_node->args) {
+      new_args.push_back(VisitExpr(arg));
+    }
+
+    return Call(call_node->op, new_args, call_node->attrs, 
call_node->sinfo_args, call_node->span);
+  }
+
+  Expr VisitExpr_(const FunctionNode* func_node) override {
+    Function func = GetRef<Function>(func_node);
+    auto opt_codegen = func->GetAttr<String>(attr::kCodegen);
+    if (opt_codegen) {
+      auto ext_symbol = GetExtSymbol(func);
+      size_t count = 0;
+      PostOrderVisit(func->body, [=, &count](Expr e) {
+        if (e->IsInstance<ConstantNode>()) {
+          // Make sure to pick a unique name
+          auto name = ext_symbol + "_" + opt_codegen.value() + "_const_" + 
std::to_string(count++);
+          auto constant = Downcast<Constant>(e);
+          constant_names.Set(constant, name);
+        }
+      });
+      return ExternFunc(GetExtSymbol(func));
+    } else {
+      return ExprMutator::VisitExpr_(func_node);
+    }
+  }
+
+ private:
+  Array<runtime::Module> InvokeCodegen(IRModule mod, Map<String, OptionMap> 
target_options) {
+    std::unordered_map<std::string, Array<Function>> target_functions;
+
+    for (const auto& entry : mod->functions) {
+      PostOrderVisit(entry.second, [&target_functions](Expr e) {
+        if (e->IsInstance<FunctionNode>()) {
+          auto f = Downcast<Function>(e);
+          if (auto target_opt = f->GetAttr<String>(attr::kCodegen)) {
+            String target = target_opt.value();
+            target_functions[target].push_back(f);
+          }
+        }
+      });
+    }
+
+    Array<runtime::Module> ext_mods;
+
+    for (const auto& [target, functions] : target_functions) {
+      OptionMap options = target_options.Get(target).value_or({});
+      // Start the codegen process.
+      // Get the codegen with its ffi key.
+      String codegen_name = "relax.ext." + target;
+      auto codegen = runtime::Registry::Get(codegen_name);
+      ICHECK(codegen) << "Codegen is not found: " << codegen_name << "\n";
+
+      Array<runtime::Module> compiled_functions = (*codegen)(functions, 
options, constant_names);
+      ext_mods.insert(ext_mods.end(), compiled_functions.begin(), 
compiled_functions.end());
+    }
+
+    return ext_mods;
+  }
+
+  /*! \brief The names of all constants in the original module. */
+  Map<Constant, String> constant_names;
+  /*! \brief Extern funcs and their return struct infos for each global 
variable.  */
+  std::unordered_map<const GlobalVarNode*, std::pair<Expr, StructInfo>> 
extern_funcs_;
+};
+
+}  // namespace relax
+
+namespace transform {
/*!
 * \brief Create a module pass that runs external codegen on functions
 * annotated with kCodegen.
 * \param target_options Per-target codegen options, keyed by target name.
 * \param entry_functions Entry functions to rewrite and keep in the module.
 */
Pass RunCodegen(Optional<Map<String, Map<String, ObjectRef>>> target_options,
                Array<String> entry_functions) {
  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule m,
                                                                            PassContext pc) {
    return relax::CodeGenRunner(m).Run(target_options, entry_functions);
  };
  return CreateModulePass(pass_func, 0, "RunCodegen", {});
}
+
+TVM_REGISTER_GLOBAL("relax.transform.RunCodegen").set_body_typed(RunCodegen);
+
+}  // namespace transform
+}  // namespace tvm
diff --git a/tests/python/relax/test_codegen_tensorrt.py 
b/tests/python/relax/test_codegen_tensorrt.py
new file mode 100644
index 0000000000..164cf3a818
--- /dev/null
+++ b/tests/python/relax/test_codegen_tensorrt.py
@@ -0,0 +1,124 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+import numpy as np
+import tvm
+import tvm.testing
+
+from tvm import relax, relay
+from tvm.script import relax as R
+from tvm.relax.dpl import make_fused_bias_activation_pattern, is_op, wildcard
+
+
def get_relay_residual_block(d_shape, w_shape):
    """Build the Relay reference model: two conv2d+relu layers plus a residual add."""
    data = relay.var("data", shape=d_shape)
    weight1 = relay.var("weight1", shape=w_shape)
    weight2 = relay.var("weight2", shape=w_shape)
    act1 = relay.nn.relu(relay.nn.conv2d(data=data, weight=weight1, padding=(1, 1)))
    act2 = relay.nn.relu(relay.nn.conv2d(data=act1, weight=weight2, padding=(1, 1)))
    return act2 + data
+
+
@tvm.script.ir_module
class Conv2dResidualBlock:
    # Relax twin of get_relay_residual_block: conv2d+relu twice, then a residual add.
    @R.function
    def main(
        data: R.Tensor((1, 64, 56, 56), "float32"),
        weight1: R.Tensor((64, 64, 3, 3), "float32"),
        weight2: R.Tensor((64, 64, 3, 3), "float32"),
    ):
        with R.dataflow():
            conv1 = relax.op.nn.relu(relax.op.nn.conv2d(data, weight1, padding=(1, 1)))
            conv2 = relax.op.nn.relu(relax.op.nn.conv2d(conv1, weight2, padding=(1, 1)))
            out = relax.op.add(conv2, data)
            R.output(out)

        return out
+
+
+has_tensorrt = tvm.get_global_func("relax.ext.tensorrt", True)
+
+tensorrt_enabled = pytest.mark.skipif(
+    not has_tensorrt,
+    reason="TENSORRT not enabled.",
+)
+
+pytestmark = [tensorrt_enabled]
+
+
def test_tensorrt_offload():
    """End-to-end check: offload the residual block to TensorRT on CUDA and
    compare against the equivalent Relay model run on the CPU."""
    weight1_np = np.random.randn(64, 64, 3, 3).astype("float32")
    weight2_np = np.random.randn(64, 64, 3, 3).astype("float32")

    # Patterns cover every op in the block, so the whole graph is offloaded.
    conv_pat = make_fused_bias_activation_pattern(
        "relax.nn.conv2d", with_bias=False, activation=None
    )
    relu_pat = is_op("relax.nn.relu")(wildcard())
    add_pat = is_op("relax.add")(wildcard(), wildcard())

    patterns = [
        ("tensorrt.nn.conv2d", conv_pat),
        ("tensorrt.nn.relu", relu_pat),
        ("tensorrt.add", add_pat),
    ]

    params_np = {"weight1": weight1_np, "weight2": weight2_np}

    # Bind weights, partition by pattern, merge the partitions into one
    # region, then run the TensorRT codegen over the annotated functions.
    mod = tvm.transform.Sequential(
        [
            relax.transform.BindParams("main", params_np),
            relax.transform.FuseOpsByPattern(patterns),
            relax.transform.MergeCompositeFunctions(),
            relax.transform.RunCodegen(),
        ]
    )(Conv2dResidualBlock)

    target = "cuda"
    dev = tvm.device(target, 0)
    ex = relax.vm.build(mod, target)

    vm = relax.VirtualMachine(ex, dev)
    f = vm["main"]

    data_np = np.random.randn(1, 64, 56, 56).astype("float32")
    out = f(tvm.nd.array(data_np, dev)).numpy()

    # Reference: same network in Relay, executed on the CPU graph executor.
    relay_mod = tvm.IRModule.from_expr(get_relay_residual_block(data_np.shape, weight1_np.shape))

    ref = (
        relay.create_executor("graph", mod=relay_mod, device=tvm.cpu(0), target="llvm")
        .evaluate()(*[data_np, weight1_np, weight2_np])
        .numpy()
    )

    tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-3)
diff --git a/tests/python/relax/test_transform_codegen_pass.py 
b/tests/python/relax/test_transform_codegen_pass.py
new file mode 100644
index 0000000000..e50ad8f5f4
--- /dev/null
+++ b/tests/python/relax/test_transform_codegen_pass.py
@@ -0,0 +1,260 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+import os
+import tvm
+import tvm.testing
+from tvm import relax
+import numpy as np
+from tvm.script import relax as R
+from tvm.relax.testing import transform
+import tempfile
+from tvm.relax.transform.tuning_api import Trace
+from tvm.relax.dpl import is_op, wildcard
+
+# Probe the compile-time (codegen) and run-time TensorRT hooks separately;
+# the `True` flag makes get_global_func return None rather than raise when
+# the symbol is not registered.
+env_checker_codegen = tvm.get_global_func("relax.ext.tensorrt", True)
+env_checker_runtime = tvm.get_global_func("relax.is_tensorrt_runtime_enabled", 
True)
+
+has_tensorrt_codegen = pytest.mark.skipif(
+    not env_checker_codegen,
+    reason="TensorRT codegen not available",
+)
+# The runtime checker must both exist and report the runtime as enabled.
+has_tensorrt_runtime = pytest.mark.skipif(
+    not env_checker_runtime or not env_checker_runtime(),
+    reason="TensorRT runtime not available",
+)
+
+# Global variable in pytest that applies markers to all tests.
+pytestmark = [has_tensorrt_codegen, has_tensorrt_runtime]
+
+# Target gpu
+target_str = "nvidia/geforce-rtx-3070"  # "nvidia/nvidia-t4"
+target = tvm.target.Target(target_str)
+dev = tvm.cuda()
+
+
+def check_executable(exec, dev, inputs, expected):
+    """Run `main` of the given VM executable on `dev` with `inputs` and
+    assert the output matches `expected` within 1e-5 tolerance.
+
+    NOTE(review): the parameter name `exec` shadows the Python builtin;
+    consider renaming (left as-is here since it is part of the signature).
+    """
+    vm = relax.VirtualMachine(exec, dev)
+    out = vm["main"](*inputs)
+    tvm.testing.assert_allclose(out.numpy(), expected.numpy(), atol=1e-5, 
rtol=1e-5)
+
+
+def check_roundtrip(exec0, dev, inputs, expected):
+    """Serialize `exec0` to a shared library, reload it, and verify the
+    reloaded executable is equivalent (same stats and text form) and
+    produces the same results as the original.
+
+    NOTE(review): writes `exec.so` into the current working directory; a
+    tempfile would avoid collisions between parallel test runs.
+    """
+    exec0.mod.export_library("exec.so")
+    exec1 = relax.vm.Executable(tvm.runtime.load_module("exec.so"))
+    os.remove("exec.so")
+    assert exec0.stats() == exec1.stats()
+    assert exec0.as_text() == exec1.as_text()
+
+    check_executable(exec0, dev, inputs, expected)
+    check_executable(exec1, dev, inputs, expected)
+
+
+def gen_ground_truth(mod, target, dev, inputs):
+    """Build and run `mod` without BYOC to produce the reference output.
+
+    Legalizes the ops, tunes with MetaSchedule (small budget,
+    max_trials_global=8), applies the tuning database, builds for `target`,
+    and executes `main` on `inputs`.
+    """
+    # Lower and run tuning
+    # Since there is no default schedule for GPU in MS yet, this is necessary
+    with tempfile.TemporaryDirectory() as work_dir:
+        with target, tvm.transform.PassContext(trace=Trace(mod), opt_level=0):
+            seq = tvm.transform.Sequential(
+                [
+                    relax.transform.LegalizeOps(),
+                    relax.transform.MetaScheduleTuneIRMod(
+                        params={}, work_dir=work_dir, max_trials_global=8
+                    ),
+                    relax.transform.MetaScheduleApplyDatabase(work_dir),
+                ]
+            )
+            new_mod = seq(mod)
+    assert relax.analysis.well_formed(new_mod)
+    # NOTE(review): `exec` shadows the Python builtin.
+    exec = relax.vm.build(new_mod, target, params={})
+    vm = relax.VirtualMachine(exec, dev)
+    return vm["main"](*inputs)
+
+
+@tvm.script.ir_module
+class InputModule:
+    # Element-wise workload: a chain of multiplies and adds, so the tests can
+    # offload every op (test_tensorrt_only) or only `add` (the mixed test).
+    @R.function
+    def main(
+        x: R.Tensor((16, 16), "float32"), y: R.Tensor((16, 16), "float32")
+    ) -> R.Tensor((16, 16), "float32"):
+        with R.dataflow():
+            z1 = R.multiply(x, y)
+            z2 = R.add(z1, x)
+            z3 = R.add(z1, z2)
+            z4 = R.multiply(z3, z2)
+            z5 = R.add(z4, z1)
+            R.output(z5)
+        return z5
+
+
+def setup_test():
+    """Return (mod, inputs, expected): the test workload, random inputs
+    placed on `dev`, and the reference output from the pure-TVM pipeline."""
+    # Prepare IRModule and its input
+    mod = InputModule
+    assert isinstance(mod, tvm.IRModule)
+
+    np0 = np.random.rand(16, 16).astype(np.float32)
+    np1 = np.random.rand(16, 16).astype(np.float32)
+    data0 = tvm.nd.array(np0, dev)
+    data1 = tvm.nd.array(np1, dev)
+    inputs = [data0, data1]
+
+    # Ground truth should be generated before annotation
+    # due to the conflict with MS task extraction
+    # TODO(@sunggg): Sort this out
+    expected = gen_ground_truth(mod, target, dev, inputs)
+    return mod, inputs, expected
+
+
+@tvm.testing.requires_gpu
+def test_tensorrt_only():
+    """Offload the entire model (both `multiply` and `add`) to TensorRT and
+    check numerics plus the executable serialization round-trip."""
+    mod, inputs, expected = setup_test()
+
+    # Define patterns that we want to offload to byoc
+    # This test will offload entire model
+    # Thus, define patterns for both `multiply` and `add` ops
+    patterns = [
+        ("tensorrt.multiply", is_op("relax.multiply")(wildcard(), wildcard())),
+        ("tensorrt.add", is_op("relax.add")(wildcard(), wildcard())),
+    ]
+
+    new_mod = tvm.transform.Sequential(
+        [
+            relax.transform.FuseOpsByPattern(patterns),
+            relax.transform.MergeCompositeFunctions(),
+            relax.transform.RunCodegen(),
+        ]
+    )(mod)
+
+    ex0 = relax.vm.build(new_mod, target, params={})
+    # Sanity check for the correctness and roundtrip
+    check_roundtrip(ex0, dev, inputs, expected)
+
+
+@tvm.testing.requires_gpu
+def test_mix_use_tensorrt_and_tvm():
+    mod, inputs, expected = setup_test()
+
+    # Define patterns that we want to offload to byoc
+    # This test will only offload `add` op to tensorrt
+    # and tune `multiply` op with MetaSchedule
+    patterns = [
+        ("tensorrt.add", is_op("relax.add")(wildcard(), wildcard())),
+    ]
+
+    # Run Codegen pass
+    with tempfile.TemporaryDirectory() as work_dir:
+        with target, tvm.transform.PassContext(trace=Trace(mod), opt_level=0):
+            new_mod = tvm.transform.Sequential(
+                [
+                    relax.transform.FuseOpsByPattern(patterns),
+                    relax.transform.MergeCompositeFunctions(),
+                    relax.transform.RunCodegen(),
+                    relax.transform.LegalizeOps(),
+                    relax.transform.MetaScheduleTuneIRMod(
+                        params={}, work_dir=work_dir, max_trials_global=8
+                    ),
+                    relax.transform.MetaScheduleApplyDatabase(work_dir),
+                ]
+            )(mod)
+    assert relax.analysis.well_formed(new_mod)
+    with transform.PassContext(opt_level=0):
+        ex0 = relax.vm.build(new_mod, target, params={})
+
+    # Sanity check for the correctness and rountrip
+    check_roundtrip(ex0, dev, inputs, expected)
+
+
+@tvm.script.ir_module
+# IRModule hand-written in the shape produced by FuseOpsByPattern +
+# MergeCompositeFunctions: the outer function carries the "Codegen" and
+# "global_symbol" attrs, and its nested function carries the "Composite"
+# pattern name identifying the offloaded op. `main` calls the same offloaded
+# function twice, which test_multiple_calls_same_extern relies on.
+# NOTE(review): several long lines below were wrapped by the mail archive;
+# each unprefixed continuation line belongs to the line above it.
+class Conv2dx2:
+    @R.function
+    def main(
+        data: R.Tensor((16, 32, 32, 16), dtype="float16"),
+        weight1: R.Tensor((16, 3, 3, 16), dtype="float16"),
+        weight2: R.Tensor((16, 3, 3, 16), dtype="float16"),
+    ) -> R.Tensor((16, 32, 32, 16), dtype="float16"):
+        with R.dataflow():
+            lv: R.Tensor((16, 32, 32, 16), dtype="float16") = 
fused_relax_nn_conv2d_tensorrt(
+                data, weight1
+            )
+            gv: R.Tensor((16, 32, 32, 16), dtype="float16") = 
fused_relax_nn_conv2d_tensorrt(
+                lv, weight2
+            )
+            R.output(gv)
+        return gv
+
+    @R.function
+    def fused_relax_nn_conv2d_tensorrt(
+        data: R.Tensor((16, 32, 32, 16), dtype="float16"),
+        weight1: R.Tensor((16, 3, 3, 16), dtype="float16"),
+    ) -> R.Tensor((16, 32, 32, 16), dtype="float16"):
+        R.func_attr({"Codegen": "tensorrt", "global_symbol": 
"fused_relax_nn_conv2d_tensorrt"})
+
+        @R.function
+        def gv(
+            data_1: R.Tensor((16, 32, 32, 16), dtype="float16"),
+            weight1_1: R.Tensor((16, 3, 3, 16), dtype="float16"),
+        ) -> R.Tensor((16, 32, 32, 16), dtype="float16"):
+            R.func_attr({"Composite": "tensorrt.conv2d", "Primitive": 1})
+            with R.dataflow():
+                gv_1: R.Tensor((16, 32, 32, 16), dtype="float16") = 
R.nn.conv2d(
+                    data_1,
+                    weight1_1,
+                    padding=[1, 1, 1, 1],
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                    out_layout="NHWC",
+                )
+                R.output(gv_1)
+            return gv_1
+
+        gv1: R.Tensor((16, 32, 32, 16), dtype="float16") = gv(data, weight1)
+        return gv1
+
+
+@tvm.script.ir_module
+# Expected `main` after RunCodegen on Conv2dx2: both call sites become
+# R.call_tir invocations of the single external symbol
+# "fused_relax_nn_conv2d_tensorrt".
+class Conv2dx2_after:
+    @R.function
+    def main(
+        data: R.Tensor((16, 32, 32, 16), dtype="float16"),
+        weight1: R.Tensor((16, 3, 3, 16), dtype="float16"),
+        weight2: R.Tensor((16, 3, 3, 16), dtype="float16"),
+    ) -> R.Tensor((16, 32, 32, 16), dtype="float16"):
+        with R.dataflow():
+            lv = R.call_tir(
+                "fused_relax_nn_conv2d_tensorrt",
+                (data, weight1),
+                out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"),
+            )
+            gv = R.call_tir(
+                "fused_relax_nn_conv2d_tensorrt",
+                (lv, weight2),
+                out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"),
+            )
+            R.output(gv)
+        return gv
+
+
+def test_multiple_calls_same_extern():
+    # Two call sites share one external function; RunCodegen must rewrite
+    # `main` to structurally match Conv2dx2_after["main"].
+    mod = relax.transform.RunCodegen()(Conv2dx2)
+    tvm.ir.assert_structural_equal(mod["main"], Conv2dx2_after["main"])
+
+
+# TODO(@sunggg):  test with more complex patterns (e.g., multiple annots, mixed codegens, different ops, const binding)
+
+# Allow running this file directly without invoking the pytest CLI.
+if __name__ == "__main__":
+    pytest.main([__file__])

Reply via email to