This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 5e2e942d5a [Unity][Transform] Automatic Mixed Precision (#14242)
5e2e942d5a is described below

commit 5e2e942d5a584f07f30c285ba3391bc51a20890a
Author: Bohan Hou <[email protected]>
AuthorDate: Sun Mar 19 12:05:42 2023 -0400

    [Unity][Transform] Automatic Mixed Precision (#14242)
    
    This PR adds a new pass ToMixedPrecision to automatically cast fp32 models to fp16 when necessary.
    
    see https://github.com/spectrometerHBH/tvm/blob/amp/src/relax/transform/to_mixed_precision.cc#L51 on how this pass works.
---
 include/tvm/relax/attrs/datatype.h                 |   9 +
 include/tvm/relax/transform.h                      |   9 +-
 python/tvm/relax/op/datatype.py                    |  17 +
 python/tvm/relax/transform/transform.py            |  46 +-
 python/tvm/script/ir_builder/relax/ir.py           |   2 +
 src/relax/op/image/resize.cc                       |   3 +-
 src/relax/op/nn/convolution.cc                     |  12 +-
 src/relax/op/nn/nn.cc                              |   9 +-
 src/relax/op/nn/pooling.cc                         |   9 +-
 src/relax/op/op_common.h                           |  12 +-
 src/relax/op/tensor/binary.h                       |  23 +-
 src/relax/op/tensor/create.cc                      |  12 +-
 src/relax/op/tensor/datatype.cc                    |  30 +-
 src/relax/op/tensor/datatype.h                     |   8 +
 src/relax/op/tensor/index.cc                       |   3 +-
 src/relax/op/tensor/linear_algebra.cc              |   8 +-
 src/relax/op/tensor/manipulate.cc                  |  27 +-
 src/relax/op/tensor/ternary.cc                     |   3 +-
 src/relax/transform/infer_amp_utils.cc             |  59 +++
 src/relax/transform/infer_amp_utils.h              |  85 ++++
 src/relax/transform/to_mixed_precision.cc          | 538 ++++++++++++++++++++
 tests/python/relax/test_op_datatype.py             |  17 +
 .../relax/test_transform_to_mixed_precision.py     | 540 +++++++++++++++++++++
 23 files changed, 1408 insertions(+), 73 deletions(-)

diff --git a/include/tvm/relax/attrs/datatype.h b/include/tvm/relax/attrs/datatype.h
index 79cb345688..c5a5a4e7d2 100644
--- a/include/tvm/relax/attrs/datatype.h
+++ b/include/tvm/relax/attrs/datatype.h
@@ -38,6 +38,15 @@ struct AstypeAttrs : public tvm::AttrsNode<AstypeAttrs> {
   }
 };  // struct AstypeAttrs.
 
+/*! \brief Attributes used in wrap_param operator */
+struct WrapParamAttrs : public tvm::AttrsNode<WrapParamAttrs> {
+  DataType dtype;
+
+  TVM_DECLARE_ATTRS(WrapParamAttrs, "relax.attrs.WrapParamAttrs") {
+    TVM_ATTR_FIELD(dtype).describe("Target data type");
+  }
+};  // struct WrapParamAttrs.
+
 }  // namespace relax
 }  // namespace tvm
 
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 369290e661..eba7de1b0c 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -385,7 +385,6 @@ TVM_DLL Pass AlterOpImpl(const Map<String, tir::PrimFunc>& op_impl_map,
  */
 TVM_DLL Pass ConvertLayout(Map<String, Array<String>> desired_layouts);
 
-
 /*!
  * \brief Dead code elimination.
  * Currently it removes:
@@ -401,6 +400,14 @@ TVM_DLL Pass ConvertLayout(Map<String, Array<String>> desired_layouts);
  */
 TVM_DLL Pass DeadCodeElimination(Array<runtime::String> entry_functions);
 
+/*!
+ * \brief Automatic mixed precision pass. Currently the pass assumes the input module to be fp32
+ * only, and will automatically cast fp32 to fp16 for certain ops.
+ * \param out_dtype The output data type of gemm/conv, which is the data type of the accumulator.
+ * \return The Pass.
+ */
+TVM_DLL Pass ToMixedPrecision(const DataType& out_dtype);
+
 }  // namespace transform
 }  // namespace relax
 }  // namespace tvm
diff --git a/python/tvm/relax/op/datatype.py b/python/tvm/relax/op/datatype.py
index 5c02776dd7..120487c0bd 100644
--- a/python/tvm/relax/op/datatype.py
+++ b/python/tvm/relax/op/datatype.py
@@ -40,3 +40,20 @@ def astype(x: Expr, dtype: Union[str, DataType]) -> Expr:
         The casted result.
     """
     return _ffi_api.astype(x, dtype)  # type: ignore
+
+
+def wrap_param(data: Expr, dtype: Union[str, DataType] = "float32") -> Expr:
+    """Cast input tensor which is model param to data type if the dtype of the input data is not
+    the same as the given dtype.
+    Parameters
+    ----------
+    data : relax.Expr
+        The input data to the operator.
+    dtype : Union[str, DataType]
+        The target data type
+    Returns
+    -------
+    result : relax.Expr
+        The casted result.
+    """
+    return _ffi_api.wrap_param(data, dtype)  # type: ignore
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 72768bf676..ebfd7a6765 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -657,37 +657,6 @@ def ConvertLayout(desired_layouts: Dict[str, List[str]]) -> tvm.ir.transform.Pas
     return _ffi_api.ConvertLayout(desired_layouts)  # type: ignore
 
 
-def AlterOpImpl(
-    op_impl_map: Dict[str, PrimFunc],
-    op_buffer_transforms: Dict[str, List[Union[IndexMap, Callable]]],
-):
-    """Replace all PrimFunc's which have matching 'operator_name' attribute, with replacement
-    PrimFunc that could possibly have different layouts on i/o buffers. The layout
-    transformations on i/o buffers is present in the op_buffer_transforms map. Inserts the layout
-    transformations in the call sites of PrimFuncs being replaced to transform i/o
-    tensors into expected layout by new PrimFunc.
-
-    Parameters
-    ----------
-    op_impl_map: Dict[str, PrimFunc]
-        op_kind to PrimFunc map
-    op_buffer_transforms: Dict[str, List[Union[IndexMap, Callable]]
-        op_kind to layout transformation map for each of the buffers
-    Returns
-    -------
-    ret: tvm.ir.transform.Pass
-    """
-    for operator_name, transform_list in op_buffer_transforms.items():
-        l = []
-        for transform in transform_list:
-            if isinstance(transform, Callable):
-                transform = IndexMap.from_func(transform)
-            l.append(transform)
-        op_buffer_transforms[operator_name] = l
-
-    return _ffi_api.AlterOpImpl(op_impl_map, op_buffer_transforms)  # type: ignore
-
-
 def DeadCodeElimination(entry_functions: Optional[List[str]] = None) -> tvm.ir.transform.Pass:
     """Remove dead code in the program.
        Currently it removes:
@@ -715,6 +684,21 @@ def DeadCodeElimination(entry_functions: Optional[List[str]] = None) -> tvm.ir.t
     return _ffi_api.DeadCodeElimination(entry_functions)  # type: ignore
 
 
+def ToMixedPrecision(out_dtype="float32") -> tvm.ir.transform.Pass:
+    """Automatic mixed precision pass. Currently the pass assumes the input module to be fp32
+    only, and will automatically cast fp32 to fp16 for certain ops.
+    Parameters
+    ----------
+    out_dtype : str
+        The output data type of gemm/conv, which is the data type of the accumulator.
+    Returns
+    -------
+    ret : tvm.transform.Pass
+        The registered pass for mixed precision.
+    """
+    return _ffi_api.ToMixedPrecision(out_dtype)  # type: ignore
+
+
 def _wrap_class_function_pass(pass_cls, pass_info):
     """Wrap a python class as function pass."""
 
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index b3190ea334..32d6083e8a 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -121,6 +121,7 @@ from tvm.relax.op import (
     unique,
     vm,
     where,
+    wrap_param,
     zeros,
     zeros_like,
     nn,
@@ -641,6 +642,7 @@ __all__ = [
     "variance",
     "vm",
     "where",
+    "wrap_param",
     "zeros",
     "zeros_like",
     "nn",
diff --git a/src/relax/op/image/resize.cc b/src/relax/op/image/resize.cc
index de6eec6236..6d49bea6b6 100644
--- a/src/relax/op/image/resize.cc
+++ b/src/relax/op/image/resize.cc
@@ -121,7 +121,8 @@ TVM_REGISTER_OP("relax.image.resize2d")
     .add_argument("data", "Tensor", "The input tensor.")
     .add_argument("size", "Shape", "The output image shape.")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoResize2D)
-    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutResize2d);
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutResize2d)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow);
 
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc
index f356876620..e10d205b23 100644
--- a/src/relax/op/nn/convolution.cc
+++ b/src/relax/op/nn/convolution.cc
@@ -177,13 +177,23 @@ InferLayoutOutput InferLayoutConv2d(const Call& call,
   return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs));
 }
 
+Call InferMixedPrecisionConv2d(const Call& call, const DataType& out_dtype) {
+  const auto* conv2d_attrs = call->attrs.as<Conv2DAttrs>();
+  return Downcast<Call>(conv2d(call->args[0], call->args[1], conv2d_attrs->strides,
+                               conv2d_attrs->padding, conv2d_attrs->dilation, conv2d_attrs->groups,
+                               conv2d_attrs->data_layout, conv2d_attrs->kernel_layout,
+                               conv2d_attrs->out_layout, out_dtype));
+}
+
 TVM_REGISTER_OP("relax.nn.conv2d")
     .set_num_inputs(2)
     .add_argument("data", "Tensor", "The input tensor.")
     .add_argument("weight", "Tensor", "The weight tensor.")
     .set_attrs_type<Conv2DAttrs>()
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoConv2d)
-    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutConv2d);
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutConv2d)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways)
+    .set_attr<FInferMixedPrecision>("FInferMixedPrecision", InferMixedPrecisionConv2d);
 
 /* relax.nn.conv2d_transpose */
 TVM_REGISTER_NODE_TYPE(Conv2DTransposeAttrs);
diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc
index 6bce51ca50..c3e18f8e3b 100644
--- a/src/relax/op/nn/nn.cc
+++ b/src/relax/op/nn/nn.cc
@@ -294,7 +294,8 @@ TVM_REGISTER_OP("relax.nn.layer_norm")
     .add_argument("gamma", "Tensor", "The gamma scale factor.")
     .add_argument("beta", "Tensor", "The beta offset factor.")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoLayerNorm)
-    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutLayerNorm);
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutLayerNorm)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow);
 
 /* relax.nn.group_norm */
 TVM_REGISTER_NODE_TYPE(GroupNormAttrs);
@@ -404,7 +405,8 @@ TVM_REGISTER_OP("relax.nn.group_norm")
     .add_argument("gamma", "Tensor", "The gamma scale factor.")
     .add_argument("beta", "Tensor", "The beta offset factor.")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoGroupNorm)
-    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutGroupNorm);
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutGroupNorm)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow);
 
 /* relax.nn.dropout */
 TVM_REGISTER_NODE_TYPE(DropoutAttrs);
@@ -429,7 +431,8 @@ TVM_REGISTER_OP("relax.nn.dropout")
     .set_num_inputs(1)
     .add_argument("data", "Tensor", "Input to which dropout will be applied.")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoDropout)
-    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutUnaryEwise);
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutUnaryEwise)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow);
 
 /* relax.nn.cross_entropy_with_logits */
 StructInfo InferStructInfoCrossEntropy(const Call& call, const BlockBuilder& ctx) {
diff --git a/src/relax/op/nn/pooling.cc b/src/relax/op/nn/pooling.cc
index be0a794dee..c31ce3dd0b 100644
--- a/src/relax/op/nn/pooling.cc
+++ b/src/relax/op/nn/pooling.cc
@@ -139,7 +139,8 @@ TVM_REGISTER_OP("relax.nn.max_pool2d")
     .add_argument("data", "Tensor", "The input tensor")
     .set_attrs_type<Pool2DAttrs>()
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoPool2D)
-    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutPool2d);
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutPool2d)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow);
 
 Expr avg_pool2d(Expr data, Array<IntImm> pool_size, Array<IntImm> strides, Array<IntImm> padding,
                 Array<IntImm> dilation, bool ceil_mode, String layout,
@@ -155,7 +156,8 @@ TVM_REGISTER_OP("relax.nn.avg_pool2d")
     .add_argument("data", "Tensor", "The input tensor")
     .set_attrs_type<Pool2DAttrs>()
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoPool2D)
-    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutPool2d);
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutPool2d)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow);
 
 /* relax.nn.adaptive_avg_pool2d */
 TVM_REGISTER_NODE_TYPE(AdaptivePool2DAttrs);
@@ -237,7 +239,8 @@ TVM_REGISTER_OP("relax.nn.adaptive_avg_pool2d")
     .set_num_inputs(1)
     .add_argument("data", "Tensor", "The input tensor")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoAdaptiveAvgPool2D)
-    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutAdaptiveAvgPool2D);
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutAdaptiveAvgPool2D)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow);
 
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h
index ece4c4a321..bd5f2cd4d5 100644
--- a/src/relax/op/op_common.h
+++ b/src/relax/op/op_common.h
@@ -34,6 +34,7 @@
 #include <utility>
 #include <vector>
 
+#include "../transform/infer_amp_utils.h"
 #include "../transform/infer_layout_utils.h"
 
 namespace tvm {
@@ -70,11 +71,12 @@ inline TensorStructInfo GetUnaryInputTensorStructInfo(const Call& call, const Bl
  * \param OpRegName The name of operator to register. The name passed in will
  * be prepended with a prefix "relax." as the identifier string in the operator registry.
  */
-#define RELAX_REGISTER_UNARY_OP(OpRegName)              \
-  TVM_REGISTER_OP("relax." OpRegName)                   \
-      .set_num_inputs(1)                                \
-      .add_argument("x", "Tensor", "The input tensor.") \
-      .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutUnaryEwise)
+#define RELAX_REGISTER_UNARY_OP(OpRegName)                                     \
+  TVM_REGISTER_OP("relax." OpRegName)                                          \
+      .set_num_inputs(1)                                                       \
+      .add_argument("x", "Tensor", "The input tensor.")                        \
+      .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutUnaryEwise) \
+      .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow)
 
 /*!
  * \brief Quick helper macro to expose a make-function to construct the operator.
diff --git a/src/relax/op/tensor/binary.h b/src/relax/op/tensor/binary.h
index 197110c000..086e37f883 100644
--- a/src/relax/op/tensor/binary.h
+++ b/src/relax/op/tensor/binary.h
@@ -37,17 +37,18 @@ namespace relax {
  *  1. be prepended with a prefix "relax.op." as the FFI identifier string for the make function,
  *  2. be prepended with a prefix "relax." as the identifier string in the operator registry.
  */
-#define RELAX_REGISTER_BINARY_OP_AND_IMPL(OpName)                  \
-  Expr OpName(Expr x1, Expr x2) {                                  \
-    static const Op& op = Op::Get("relax." #OpName);               \
-    return Call(op, {x1, x2}, Attrs(), {});                        \
-  }                                                                \
-  TVM_REGISTER_GLOBAL("relax.op." #OpName).set_body_typed(OpName); \
-  TVM_REGISTER_OP("relax." #OpName)                                \
-      .set_num_inputs(2)                                           \
-      .add_argument("x1", "Tensor", "The first input tensor.")     \
-      .add_argument("x2", "Tensor", "The second input tensor.")    \
-      .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutBinaryEwise)
+#define RELAX_REGISTER_BINARY_OP_AND_IMPL(OpName)                               \
+  Expr OpName(Expr x1, Expr x2) {                                               \
+    static const Op& op = Op::Get("relax." #OpName);                            \
+    return Call(op, {x1, x2}, Attrs(), {});                                     \
+  }                                                                             \
+  TVM_REGISTER_GLOBAL("relax.op." #OpName).set_body_typed(OpName);              \
+  TVM_REGISTER_OP("relax." #OpName)                                             \
+      .set_num_inputs(2)                                                        \
+      .add_argument("x1", "Tensor", "The first input tensor.")                  \
+      .add_argument("x2", "Tensor", "The second input tensor.")                 \
+      .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutBinaryEwise) \
+      .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow)
 
 #define RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(OpName)             \
   RELAX_REGISTER_BINARY_OP_AND_IMPL(OpName).set_attr<FInferStructInfo>( \
diff --git a/src/relax/op/tensor/create.cc b/src/relax/op/tensor/create.cc
index e8374d1981..d4e5e166b7 100644
--- a/src/relax/op/tensor/create.cc
+++ b/src/relax/op/tensor/create.cc
@@ -82,7 +82,8 @@ TVM_REGISTER_OP("relax.full")
     .set_num_inputs(2)
     .add_argument("shape", "Shape", "The shape of the created tensor.")
     .add_argument("fill_value", "Tensor", "The scalar tensor, denoting the value to fill.")
-    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoFull);
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoFull)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow);
 
 /* relax.full_like */
 Expr full_like(Expr x, Expr fill_value, DataType dtype) {
@@ -119,7 +120,8 @@ TVM_REGISTER_OP("relax.full_like")
     .set_num_inputs(2)
     .add_argument("x", "Tensor", "The input tensor.")
     .add_argument("fill_value", "Tensor", "The scalar value to fill.")
-    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoFullLike);
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoFullLike)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow);
 
 // Structure info inference for ones and zeros
 StructInfo InferStructInfoOnesZeros(const Call& call, const BlockBuilder& ctx) {
@@ -175,7 +177,8 @@ TVM_REGISTER_OP("relax.ones")
     .set_attrs_type<InitAttrs>()
     .set_num_inputs(1)
     .add_argument("shape", "Shape", "The shape of the created tensor.")
-    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoOnesZeros);
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoOnesZeros)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow);
 
 TVM_REGISTER_OP("relax.ones_like")
     .set_attrs_type<InitAttrs>()
@@ -207,7 +210,8 @@ TVM_REGISTER_OP("relax.zeros")
     .set_attrs_type<InitAttrs>()
     .set_num_inputs(1)
     .add_argument("shape", "Shape", "The shape of the created tensor.")
-    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoOnesZeros);
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoOnesZeros)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow);
 
 TVM_REGISTER_OP("relax.zeros_like")
     .set_attrs_type<InitAttrs>()
diff --git a/src/relax/op/tensor/datatype.cc b/src/relax/op/tensor/datatype.cc
index 349a54ee4d..18747fedcd 100644
--- a/src/relax/op/tensor/datatype.cc
+++ b/src/relax/op/tensor/datatype.cc
@@ -55,7 +55,35 @@ TVM_REGISTER_OP("relax.astype")
     .set_num_inputs(1)
     .add_argument("x", "Tensor", "The input tensor")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoAstype)
-    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutUnaryEwise);
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutUnaryEwise)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow);
+
+/* relax.wrap_param */
+TVM_REGISTER_NODE_TYPE(WrapParamAttrs);
+
+Expr MakeWrapParam(Expr data, DataType dtype) {
+  ObjectPtr<WrapParamAttrs> attrs = make_object<WrapParamAttrs>();
+  attrs->dtype = dtype;
+
+  static const Op& op = Op::Get("relax.wrap_param");
+  return Call(op, {std::move(data)}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.wrap_param").set_body_typed(MakeWrapParam);
+
+StructInfo InferStructInfoWrapParam(const Call& call, const BlockBuilder& ctx) {
+  TensorStructInfo sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+  const auto* attrs = call->attrs.as<WrapParamAttrs>();
+  ObjectPtr<TensorStructInfoNode> new_sinfo = make_object<TensorStructInfoNode>(*sinfo.get());
+  new_sinfo->dtype = attrs->dtype;
+  return TensorStructInfo(new_sinfo);
+}
+
+TVM_REGISTER_OP("relax.wrap_param")
+    .set_attrs_type<WrapParamAttrs>()
+    .set_num_inputs(1)
+    .add_argument("data", "Tensor", "The input tensor")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoWrapParam);
 
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/op/tensor/datatype.h b/src/relax/op/tensor/datatype.h
index 6afa7a50d4..b612c45fc9 100644
--- a/src/relax/op/tensor/datatype.h
+++ b/src/relax/op/tensor/datatype.h
@@ -39,6 +39,14 @@ namespace relax {
  */
 Expr astype(Expr x, DataType dtype);
 
+/*!
+ * \brief A wrapper to wrap the input const tensor to the given data type.
+ * \param x The input const tensor to the operator.
+ * \param dtype The target data type
+ * \return The wrapped result.
+ */
+Expr wrap_param(Expr x, DataType dtype);
+
 }  // namespace relax
 }  // namespace tvm
 
diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc
index 29f668ccf3..218de6e2c6 100644
--- a/src/relax/op/tensor/index.cc
+++ b/src/relax/op/tensor/index.cc
@@ -210,7 +210,8 @@ TVM_REGISTER_OP("relax.strided_slice")
     .set_num_inputs(1)
     .add_argument("x", "Tensor", "The source tensor to be sliced.")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoStridedSlice)
-    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutStridedSlice);
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutStridedSlice)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow);
 
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/op/tensor/linear_algebra.cc b/src/relax/op/tensor/linear_algebra.cc
index 50b53d0c8e..afcc7fefe7 100644
--- a/src/relax/op/tensor/linear_algebra.cc
+++ b/src/relax/op/tensor/linear_algebra.cc
@@ -113,11 +113,17 @@ StructInfo InferStructInfoMatmul(const Call& call, const BlockBuilder& ctx) {
   return TensorStructInfo(ShapeExpr(output_shape), out_dtype);
 }
 
+Call InferMixedPrecisionMatmul(const Call& call, const DataType& out_dtype) {
+  return Downcast<Call>(matmul(call->args[0], call->args[1], out_dtype));
+}
+
 TVM_REGISTER_OP("relax.matmul")
     .set_num_inputs(2)
     .add_argument("x1", "Tensor", "The first input tensor.")
     .add_argument("x2", "Tensor", "The second input tensor.")
-    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoMatmul);
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoMatmul)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways)
+    .set_attr<FInferMixedPrecision>("FInferMixedPrecision", InferMixedPrecisionMatmul);
 
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index 49f745608f..dbeb6f8d5b 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -107,7 +107,8 @@ TVM_REGISTER_OP("relax.broadcast_to")
     .set_num_inputs(2)
     .add_argument("x", "Tensor", "The input tensor.")
     .add_argument("shape", "Shape", "The target shape.")
-    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoBroadcastTo);
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoBroadcastTo)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow);
 
 /* relax.concat */
 TVM_REGISTER_NODE_TYPE(ConcatAttrs);
@@ -301,7 +302,8 @@ TVM_REGISTER_OP("relax.concat")
     .set_num_inputs(1)
     .add_argument("tensors", "Tuple of Tensors", "The input list of tensors.")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoConcat)
-    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutConcat);
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutConcat)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow);
 
 /* relax.expand_dims */
 TVM_REGISTER_NODE_TYPE(ExpandDimsAttrs);
@@ -397,7 +399,8 @@ TVM_REGISTER_OP("relax.expand_dims")
     .set_attrs_type<ExpandDimsAttrs>()
     .add_argument("x", "Tensor", "The input tensor.")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoExpandDims)
-    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutExpandDims);
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutExpandDims)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow);
 
 // Helper function for flatten and reshape.
 PrimExpr ComputeShapeProduct(const Array<PrimExpr>& shape_values) {
@@ -437,7 +440,8 @@ StructInfo InferStructInfoFlatten(const Call& call, const BlockBuilder& ctx) {
 TVM_REGISTER_OP("relax.flatten")
     .set_num_inputs(1)
     .add_argument("x", "Tensor", "The input tensor.")
-    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoFlatten);
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoFlatten)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow);
 
 /* relax.layout_transform */
 TVM_REGISTER_NODE_TYPE(LayoutTransformAttrs);
@@ -499,7 +503,8 @@ TVM_REGISTER_OP("relax.layout_transform")
     .set_num_inputs(1)
     .set_attrs_type<LayoutTransformAttrs>()
     .add_argument("x", "Tensor", "The input tensor.")
-    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoLayoutTransform);
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoLayoutTransform)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow);
 
 /* relax.permute_dims */
 TVM_REGISTER_NODE_TYPE(PermuteDimsAttrs);
@@ -610,7 +615,8 @@ TVM_REGISTER_OP("relax.permute_dims")
     .set_num_inputs(1)
     .add_argument("x", "Tensor", "The input tensor.")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoPermuteDims)
-    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutPermuteDims);
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutPermuteDims)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow);
 
 /* relax.reshape */
 Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) {
@@ -752,7 +758,8 @@ TVM_REGISTER_OP("relax.reshape")
     .set_num_inputs(2)
     .add_argument("x", "Tensor", "The input tensor.")
     .add_argument("shape", "Shape", "The input new shape.")
-    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoReshape);
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoReshape)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow);
 
 /* relax.split */
 TVM_REGISTER_NODE_TYPE(SplitAttrs);
@@ -885,7 +892,8 @@ TVM_REGISTER_OP("relax.split")
     .set_num_inputs(1)
     .add_argument("x", "Tensor", "The input tensor.")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoSplit)
-    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutSplit);
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutSplit)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow);
 
 /* relax.squeeze */
 TVM_REGISTER_NODE_TYPE(SqueezeAttrs);
@@ -1040,7 +1048,8 @@ TVM_REGISTER_OP("relax.squeeze")
     .set_attrs_type<SqueezeAttrs>()
     .add_argument("x", "Tensor", "The input tensor.")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoSqueeze)
-    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutSqueeze);
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutSqueeze)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow);
 
 void CheckCollapseShape(const Call& call, const BlockBuilder& ctx,
                         const Array<PrimExpr>& data_shape, const Array<PrimExpr>& target_shape) {
diff --git a/src/relax/op/tensor/ternary.cc b/src/relax/op/tensor/ternary.cc
index 93652f43ef..940192bd8e 100644
--- a/src/relax/op/tensor/ternary.cc
+++ b/src/relax/op/tensor/ternary.cc
@@ -111,7 +111,8 @@ TVM_REGISTER_OP("relax.ewise_fma")
     .add_argument("x2", "Tensor", "The right hand operand of the multiplication")
     .add_argument("x3", "Tensor", "The operand of the addition")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoEwiseFMA)
-    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutEwiseFMA);
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutEwiseFMA)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow);
 
 Expr ewise_fma(Expr x1, Expr x2, Expr x3) {
   static const Op& op = Op::Get("relax.ewise_fma");
diff --git a/src/relax/transform/infer_amp_utils.cc b/src/relax/transform/infer_amp_utils.cc
new file mode 100644
index 0000000000..330fe9a72a
--- /dev/null
+++ b/src/relax/transform/infer_amp_utils.cc
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "infer_amp_utils.h"
+
+namespace tvm {
+namespace relax {
+
+NType NTypeFrom(const StructInfo& sinfo, DataType dtype) {
+  // A Void dtype means "keep every leaf tensor's own dtype"; any other dtype overrides
+  // every leaf. The check does not depend on the leaf, so evaluate it once.
+  const bool keep_leaf_dtype = dtype == DataType::Void();
+  auto leaf_to_ntype = [&](const StructInfo& leaf) -> NType {
+    const auto* tensor = leaf.as<TensorStructInfoNode>();
+    ICHECK(tensor) << "Expected TensorStructInfo, but got " << leaf;
+    return NType(DLDataType2String(keep_leaf_dtype ? tensor->dtype : dtype));
+  };
+  return MapToNestedMsg<String>(sinfo, leaf_to_ntype);
+}
+
+// Same as NTypeFrom(StructInfo, DataType), applied to GetStructInfo(expr).
+NType NTypeFrom(const Expr& expr, DataType dtype) { return 
NTypeFrom(GetStructInfo(expr), dtype); }
+
+NType NTypeMerge(const NType& a, const NType& b) {
+  // DTypeDecisionCollector marks "no use seen yet" with a zero-bit float leaf (unknown_). That
+  // marker is the top of the lattice, i.e. the identity of this merge. Detect it on the string
+  // rather than on the parsed bit width: String2DLDataType treats a zero width as unspecified
+  // and falls back to its default width, so the parsed value is not reliably 0.
+  const String unknown = DLDataType2String(DataType(DataType::TypeCode::kFloat, 0, 1));
+  auto fcombine = [&unknown](const String& a_str, const String& b_str) -> String {
+    if (a_str == unknown) return b_str;
+    if (b_str == unknown) return a_str;
+    DataType a_dtype = DataType(String2DLDataType(a_str));
+    DataType b_dtype = DataType(String2DLDataType(b_str));
+    ICHECK_EQ(a_dtype.code(), b_dtype.code())
+        << "Cannot merge dtypes with different type codes: " << a_str << " vs " << b_str;
+    ICHECK_EQ(a_dtype.lanes(), b_dtype.lanes())
+        << "Cannot merge dtypes with different lanes: " << a_str << " vs " << b_str;
+    // Keep the higher-precision (wider) dtype of the two.
+    return a_dtype.bits() > b_dtype.bits() ? a_str : b_str;
+  };
+  return CombineNestedMsg<String>(a, b, fcombine);
+}
+
+// Helpers that pair a fixed policy with the unchanged call: they return {Integer(policy), call}.
+// NOTE(review): this {policy, call} shape does not match the FInferMixedPrecision typedef in
+// infer_amp_utils.h (which returns a Call), and out_dtype is unused; confirm intended use.
+Array<ObjectRef> InferMixedPrecisionFollow(const Call& call, const DataType& 
out_dtype) {
+  return {Integer(MixedPrecisionPolicyKind::kFollow), call};
+}
+
+Array<ObjectRef> InferMixedPrecisionNever(const Call& call, const DataType& 
out_dtype) {
+  return {Integer(MixedPrecisionPolicyKind::kNever), call};
+}
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/transform/infer_amp_utils.h 
b/src/relax/transform/infer_amp_utils.h
new file mode 100644
index 0000000000..3c98af6db9
--- /dev/null
+++ b/src/relax/transform/infer_amp_utils.h
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file infer_amp_utils.h
+ * \brief Utility functions to be used in to_mixed_precision pass.
+ */
+
+#ifndef TVM_RELAX_TRANSFORM_INFER_AMP_UTILS_H_
+#define TVM_RELAX_TRANSFORM_INFER_AMP_UTILS_H_
+
+#include <tvm/relax/attrs/nn.h>
+#include <tvm/relax/expr.h>
+#include <tvm/relax/nested_msg.h>
+#include <tvm/relax/op_attr_types.h>
+#include <tvm/tir/data_layout.h>
+
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+namespace tvm {
+namespace relax {
+
+using runtime::DLDataType2String;
+using runtime::String;
+using runtime::String2DLDataType;
+
+/*!
+ * \brief How an op takes part in mixed precision (consumed by the ToMixedPrecision pass).
+ *  - kAlways: the args are always cast to fp16.
+ *  - kFollow: the args are cast to fp32 if any arg is fp32, otherwise to fp16.
+ *  - kNever: the args are always cast to fp32.
+ */
+enum MixedPrecisionPolicyKind : int { kAlways = 0, kFollow = 1, kNever = 2 };
+
+/*! \brief Value type of the "TMixedPrecisionPolicy" op attribute (a MixedPrecisionPolicyKind). */
+using TMixedPrecisionPolicy = int;
+
+// NType is the message we track for vars with nested TensorStructInfo; it represents the
+// realization decision of the var. Each leaf string is the name of a dtype (e.g. "float16").
+using NType = NestedMsg<String>;
+
+// Equality of two NType messages, comparing the leaf dtype strings with ==.
+struct NTypeEqual {
+  bool operator()(const NType& a, const NType& b) const {
+    auto dtype_equal = [](const String& lhs, const String& rhs) { return lhs == rhs; };
+    return Equal(a, b, dtype_equal);
+  }
+};
+
+// Construct an NType from a StructInfo. With the default (Void) dtype every leaf keeps the
+// tensor's own dtype; otherwise every leaf is set to dtype.
+NType NTypeFrom(const StructInfo& sinfo, DataType dtype = DataType::Void());
+
+// Construct an NType from the StructInfo of an Expr (same dtype semantics as above).
+NType NTypeFrom(const Expr& expr, DataType dtype = DataType::Void());
+
+// Merge two messages, we keep the higher precision type for each leaf tensor
+NType NTypeMerge(const NType& a, const NType& b);
+
+// The map that notes the NType message of each var
+using VarDTypeMap = std::unordered_map<Var, NType, ObjectPtrHash, ObjectPtrEqual>;
+
+// Signature of the "FInferMixedPrecision" op attribute, which the pass requires for kAlways ops.
+// It is called with the call node and the expected output dtype (out_dtype) and returns the
+// call to emit in its place.
+using FInferMixedPrecision =
+    runtime::TypedPackedFunc<Call(const Call& call_node, const DataType& out_dtype)>;
+
+// NOTE(review): these return {policy, call}, which does not match FInferMixedPrecision above
+// (that returns a Call); confirm how they are meant to be used.
+Array<ObjectRef> InferMixedPrecisionFollow(const Call& call, const DataType& out_dtype);
+
+Array<ObjectRef> InferMixedPrecisionNever(const Call& call, const DataType& out_dtype);
+
+}  // namespace relax
+}  // namespace tvm
+
+#endif  // TVM_RELAX_TRANSFORM_INFER_AMP_UTILS_H_
diff --git a/src/relax/transform/to_mixed_precision.cc 
b/src/relax/transform/to_mixed_precision.cc
new file mode 100644
index 0000000000..4728f81b63
--- /dev/null
+++ b/src/relax/transform/to_mixed_precision.cc
@@ -0,0 +1,538 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/relax/transform/to_mixed_precision.cc
+ * \brief Automatic mixed precision pass.
+ */
+
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/op_attr_types.h>
+#include <tvm/relax/transform.h>
+
+#include <array>
+
+#include "../op/nn/convolution.h"
+#include "../op/tensor/datatype.h"
+#include "../op/tensor/linear_algebra.h"
+#include "infer_amp_utils.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relax {
+
+using runtime::String;
+
+// Looks up the TMixedPrecisionPolicy registered for the op called by call_node.
+// Returns -1 when the callee is not an Op, and kNever for ops that registered no policy.
+int GetMixedPrecisionInfo(const CallNode* call_node) {
+  const auto* op_node = call_node->op.as<OpNode>();
+  if (!op_node) return -1;
+  const auto policy_map = Op::GetAttrMap<TMixedPrecisionPolicy>("TMixedPrecisionPolicy");
+  return policy_map.get(GetRef<Op>(op_node), MixedPrecisionPolicyKind::kNever);
+}
+
+/*!
+ * \brief Main logic to automatically cast fp32 input modules to fp16 for 
certain ops.
+ *
+ * Structurally speaking, a Relax function is composed of a series of 
VarBinding and
+ * MatchCast. And a specific class of VarBindings is the basic unit we want to 
rewrite.
+ * Formally, they are of the form:
+ *
+ * var = Call(Op, [args], attrs)
+ *
+ * where Op is a specific op we want to rewrite, and attrs is the attributes 
of the op.
+ * var and args are all exprs with type Tensor or Tuple of Tensors. They might
+ * be vars, constants, or Tuple of vars and constants.
+ * Depending on the properties of the op, we may have 3 different ways to 
rewrite it:
+ *
+ * 1. kAlways: Always cast the args to fp16
+ *    Currently, this is only used for gemm and conv ops (to favor the use of 
TensorCore)
+ *    We always cast the input args to fp16, and the dtype of the accumulator 
is configured
+ *    by the global output_dtype parameter (default to fp32). We cast the 
output to fp16.
+ *
+ * 2. kFollow: If any of the args is fp32, cast all args to fp32. Otherwise, use fp16.
+ *
+ * 3. kNever: Never cast the args to fp16. Always cast all args to fp32 (the 
original dtype).
+ *    Some ops, such as softmax, have numerical issues when using fp16. We 
will always use fp32
+ *    to ensure the correctness.
+ *
+ * Note that in this case, we will actively cast the arg to fp16 only when 
it's used in kAlways.
+ * This is to ensure that we have numerical stability to the best effort.
+ *
+ * DTypeDecisionCollector:
+ *   Note that if some tensor is only used in kAlways ops, we can store it in 
fp16 without worsening
+ *   numerical stability or using more storage. We use a backward propagation 
pass to detect such
+ *   tensors. We will store the information of each var in the only_fp16_map_.
+ *
+ *   We reuse the NType struct to store the information of each var. There are 3 kinds of info:
+ *     - Unknown (Float0): we never encounter a use of this tensor
+ *     - Float16: we only encounter uses of this tensor in kAlways ops
+ *     - Float32: we encounter some use of this tensor outside of kAlways ops
+ *   The info values form a semi-lattice, where Float0 is the top, Float16 is the middle, and
+ *   Float32 is the bottom. The lower bound of two info values is the one with more bits.
+ *
+ * ToMixedPrecisionRewriter:
+ *   We then use a forward propagation pass to rewrite the program. We keep only one specific
+ *   data type for each var, and cast the var to the required dtype locally at each use if
+ *   needed. Note that we may cast the var to the same dtype multiple times, but we decide not to
+ *   store and reuse the cast copy, due to storage concerns and to be more friendly to inlining
+ *   and operator fusion. We store the var in fp16 if it's only used in kAlways ops; otherwise we
+ *   store it in the natural output dtype of the op.
+ *
+ * The policy of each op is registered as the op attribute
+ * Op::GetAttr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy"); ops without it are treated as
+ * kNever. Ops with the kAlways policy must also register
+ * Op::GetAttr<FInferMixedPrecision>("FInferMixedPrecision"). We call it with the call (args
+ * already remapped) and the global output_dtype parameter, and it returns the new call node,
+ * which is expected to have its output_dtype set to the global output_dtype parameter.
+ *
+ * Key design: wrap_param op
+ *   We need to use fp16 parameters (which appear as constants in the program), but type
+ *   inference will fail if some parameters are fp16 and some are fp32 in the original module. To
+ *   solve this, we introduce a new op wrap_param, which wraps the original parameter and casts
+ *   it to an fp32 var.
+ *
+ *   When the rewriter encounters a wrap_param binding, it re-binds the var directly to the
+ *   wrapped parameter; uses of the var are then cast as needed.
+ */
+// Backward analysis for ToMixedPrecision: records in only_fp16_map_, for every var, the dtype
+// its uses require. Bindings are visited in reverse order (see the overrides at the bottom), so
+// the uses of a var inside a dataflow block are seen before its binding.
+class DTypeDecisionCollector : public ExprVisitor {
+ public:
+  explicit DTypeDecisionCollector(DataType output_dtype) : 
output_dtype_(output_dtype) {}
+
+  // Runs the analysis over func and returns the resulting var -> NType requirement map.
+  static VarDTypeMap Collect(Function func, DataType output_dtype) {
+    DTypeDecisionCollector collector(output_dtype);
+    collector.VisitExpr(func);
+    return std::move(collector.only_fp16_map_);
+  }
+
+ private:
+  // Returns the requirement recorded so far for var. A var seen for the first time is recorded
+  // with every leaf set to unknown_, and that message is returned.
+  NType GetDType(const Var& var) {
+    auto it = only_fp16_map_.find(var);
+    if (it == only_fp16_map_.end()) {
+      // we never encounter this var before
+      NType unknown = NTypeFrom(var, unknown_);
+      only_fp16_map_[var] = unknown;
+      return unknown;
+    }
+    return it->second;
+  }
+
+  // merge the message for a var
+  void UpdateVarDTypeMap(const Var& var, const NType& dtype) {
+    auto it = only_fp16_map_.find(var);
+    if (it == only_fp16_map_.end()) {
+      only_fp16_map_[var] = dtype;
+    } else {
+      only_fp16_map_[var] = NTypeMerge(it->second, dtype);
+    }
+  }
+
+  // merge the message for all vars in the expr list
+  // to[i] is the requirement for args[i]. Var leaves are merged into the map, Constant leaves
+  // are skipped, and any other leaf expr is a fatal error.
+  void RequireArgsToType(Array<Expr> args, Array<NType> to) {
+    ICHECK(args.size() == to.size()) << "Invalid target dtypes";
+    for (size_t i = 0; i < args.size(); ++i) {
+      auto fvisitleaf = [&](const Expr& expr, NType to) {
+        if (const auto* var = expr.as<VarNode>()) {
+          UpdateVarDTypeMap(GetRef<Var>(var), to);
+        } else if (expr->IsInstance<ConstantNode>()) {
+          // Constant can be casted anyway, so we don't need to do anything here
+          return;
+        } else {
+          LOG(FATAL) << "Unsupported argument type: " << expr->GetTypeKey();
+        }
+      };
+      DecomposeNestedMsg(args[i], to[i], fvisitleaf);
+    }
+  }
+
+  // merge the message for all vars in the expr list
+  // Only nested-tensor args are constrained: each of their leaves is required to be `to`.
+  void RequireArgsToType(Array<Expr> args, DataType to) {
+    std::vector<Expr> arg_arr;
+    std::vector<NType> to_arr;
+    for (const Expr& arg : args) {
+      if (IsNestedTensor(arg)) {
+        // only require the nested tensor args
+        arg_arr.push_back(arg);
+        to_arr.push_back(NTypeFrom(arg, to));
+      }
+    }
+    RequireArgsToType(std::move(arg_arr), std::move(to_arr));
+  }
+
+  // Reached when a var is visited directly rather than through one of the binding handlers
+  // below (e.g. the SeqExpr body); a nested-tensor var reached here is required as fp32.
+  void VisitVars_(const VarNode* op) {
+    Var var = GetRef<Var>(op);
+    if (IsNestedTensor(var)) {
+      // require the var to be fp32 (its original dtype)
+      UpdateVarDTypeMap(var, NTypeFrom(var, fp32_));
+      return;
+    }
+    ExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const VarNode* op) final { VisitVars_(op); }
+
+  void VisitExpr_(const DataflowVarNode* op) final { VisitVars_(op); }
+
+  // kAlways ops require their nested-tensor args in fp16; kFollow and kNever require fp32.
+  // Calls whose callee is not an Op (policy -1) fall back to the default visitor.
+  void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) 
final {
+    auto policy = GetMixedPrecisionInfo(call_node);
+    if (policy == -1) {
+      ExprVisitor::VisitBinding_(binding, call_node);
+      return;
+    }
+    if (policy == kAlways) {
+      // require inputs to be fp16
+      RequireArgsToType(call_node->args, fp16_);
+    } else if (policy == kFollow || policy == kNever) {
+      // require inputs to be fp32 (the original dtype)
+      RequireArgsToType(call_node->args, fp32_);
+    } else {
+      LOG(FATAL) << "Unsupported TMixedPrecisionPolicy: " << policy;
+    }
+  }
+
+  void VisitBinding_(const VarBindingNode* binding, const TupleNode* 
tuple_node) final {
+    // require input fields to be the type of the lhs field respectively
+    NType lhs_type = GetDType(binding->var);
+    RequireArgsToType(tuple_node->fields, lhs_type.NestedArray());
+  }
+
+  void VisitBinding_(const VarBindingNode* binding,
+                     const TupleGetItemNode* tuple_get_item_node) final {
+    // require the i-th field rhs tuple to be the type of the lhs
+    // (every other field of the tuple is required as unknown_)
+    NType lhs_type = GetDType(binding->var);
+    std::vector<NType> require_rhs;
+    const TupleStructInfoNode* sinfo =
+        tuple_get_item_node->tuple->struct_info_.as<TupleStructInfoNode>();
+    ICHECK(sinfo != nullptr) << "TupleGetItemNode must have TupleStructInfo";
+    for (size_t i = 0; i < sinfo->fields.size(); ++i) {
+      if (i == (size_t)tuple_get_item_node->index) {
+        require_rhs.push_back(lhs_type);
+      } else {
+        require_rhs.push_back(NTypeFrom(sinfo->fields[i], unknown_));
+      }
+    }
+    RequireArgsToType({tuple_get_item_node->tuple}, {NType(require_rhs)});
+  }
+
+  // override the following methods to visit in backward order
+  void VisitExpr_(const SeqExprNode* op) final {
+    this->VisitSpan(op->span);
+    this->VisitExpr(op->body);
+    for (auto it = op->blocks.rbegin(); it != op->blocks.rend(); it++) {
+      this->VisitBindingBlock(*it);
+    }
+
+    if (auto* sinfo = op->struct_info_.as<StructInfoNode>()) {
+      this->VisitExprDepStructInfoField(GetRef<StructInfo>(sinfo));
+    }
+  }
+
+  // Non-dataflow binding blocks are not visited, so uses inside them record no requirement.
+  void VisitBindingBlock_(const BindingBlockNode* block) { return; }
+
+  void VisitBindingBlock_(const DataflowBlockNode* block) {
+    for (auto it = block->bindings.rbegin(); it != block->bindings.rend(); 
it++) {
+      this->VisitBinding(*it);
+    }
+  }
+
+  // Branches are visited before the condition, mirroring the backward order used above.
+  void VisitExpr_(const IfNode* op) final {
+    this->VisitSpan(op->span);
+    this->VisitExpr(op->true_branch);
+    this->VisitExpr(op->false_branch);
+    this->VisitExpr(op->cond);
+
+    if (auto* sinfo = op->struct_info_.as<StructInfoNode>()) {
+      this->VisitExprDepStructInfoField(GetRef<StructInfo>(sinfo));
+    }
+  }
+
+  // Leaf marker for "no use seen yet" (a zero-bit float).
+  DataType unknown_ = DataType(DataType::TypeCode::kFloat, 0, 1);
+  DataType fp16_ = DataType(DataType::TypeCode::kFloat, 16, 1);
+  DataType fp32_ = DataType(DataType::TypeCode::kFloat, 32, 1);
+  // NOTE(review): set by the constructor but not read anywhere else in this class.
+  DataType output_dtype_;
+  VarDTypeMap only_fp16_map_;
+};
+
+// Forward rewrite for ToMixedPrecision. Inside dataflow blocks it:
+//  - casts op args according to each op's TMixedPrecisionPolicy,
+//  - stores vars that DTypeDecisionCollector found to be used only by kAlways ops in fp16,
+//  - keeps non-dataflow (output) vars in their original dtype.
+// Bindings outside dataflow blocks are handled by the default ExprMutator.
+class ToMixedPrecisionRewriter : public ExprMutator {
+ public:
+  explicit ToMixedPrecisionRewriter(const VarDTypeMap* only_fp16_map, DataType output_dtype)
+      : only_fp16_map_(only_fp16_map), output_dtype_(output_dtype) {}
+
+ private:
+  // Returns the var currently standing in for `var` in var_remap_, or `var` itself.
+  Var GetRemapped(const Var& var) {
+    auto it = var_remap_.find(var->vid);
+    return it == var_remap_.end() ? var : it->second;
+  }
+
+  // Substitutes every remapped var occurring in args.
+  Array<Expr> RemapArgs(const Array<Expr>& args) {
+    Array<Expr> new_args;
+    for (const auto& arg : args) {
+      new_args.push_back(VarReplacer::Replace(arg, var_remap_));
+    }
+    return new_args;
+  }
+
+  // Rewrites each leaf tensor of expr to the dtype named by the matching leaf of `to`, inserting
+  // astype where necessary. Only accepts exprs with nested tensor type.
+  Expr RewriteExpr(const Expr& expr, const NType& to) {
+    auto fvisitleaf = [&](const Expr& leaf, std::array<NType, 1> leaf_to) -> Expr {
+      const auto* tensor = GetStructInfoAs<TensorStructInfoNode>(leaf);
+      ICHECK(tensor != nullptr) << "Only support rewriting tensor expr";
+      // Leaf already has the requested dtype: nothing to do.
+      if (NTypeEqual()(leaf_to[0], NTypeFrom(leaf))) return leaf;
+      // Only fp16/fp32 tensors are rewritten; other dtypes (e.g. int32, float64) are kept as is.
+      if (tensor->dtype != fp16_ && tensor->dtype != fp32_) return leaf;
+      return astype(leaf, DataType(String2DLDataType(leaf_to[0].LeafValue())));
+    };
+    return TransformTupleLeaf<String>(expr, std::array<NType, 1>({to}), fvisitleaf);
+  }
+
+  // Rewrites every nested-tensor arg so that all of its leaves have dtype `to`; args that are
+  // not nested tensors are passed through unchanged.
+  Array<Expr> RewriteArgs(const Array<Expr>& args, DataType to) {
+    Array<Expr> new_args;
+    for (const Expr& arg : args) {
+      if (IsNestedTensor(arg)) {
+        new_args.push_back(RewriteExpr(arg, NTypeFrom(arg, to)));
+      } else {
+        new_args.push_back(arg);
+      }
+    }
+    return new_args;
+  }
+
+  // Returns true if any leaf of cur_type is "float32".
+  bool AnyArgIsFP32(const NType& cur_type) {
+    bool result = false;
+    auto fvisitleaf = [&](const String& dtype) {
+      if (dtype == "float32") {
+        result = true;
+      }
+    };
+    ForEachLeaf<String>(cur_type, fvisitleaf);
+    return result;
+  }
+
+  // Returns true if any leaf tensor of any nested-tensor arg is float32.
+  bool AnyArgIsFP32(const Array<Expr>& args) {
+    for (const Expr& arg : args) {
+      if (IsNestedTensor(arg)) {
+        if (AnyArgIsFP32(NTypeFrom(arg))) return true;
+      }
+    }
+    return false;
+  }
+
+  // Called for the params at the start of each dataflow block and after each binding inside one.
+  // For every leaf the collector marked "float16", emit an fp16 copy and remap `var` to it.
+  void CastIfFp16Only(const Var& var) {
+    ICHECK(builder_->CurrentBlockIsDataFlow());
+    // Get the current remapped var
+    Var cur_var = GetRemapped(var);
+    auto it = only_fp16_map_->find(var);
+    if (it == only_fp16_map_->end()) return;
+    // Per leaf: cast to fp16 if the var is fp16 only, otherwise keep the current dtype.
+    auto fcombine = [](const String& from, const String& required) -> String {
+      return required == "float16" ? required : from;
+    };
+    NType from = NTypeFrom(cur_var);
+    NType to = CombineNestedMsg<String>(from, it->second, fcombine);
+    Expr rewrite = RewriteExpr(cur_var, to);
+    // If cur_var is not rewritten, we don't need to emit a new var
+    if (!rewrite.same_as(cur_var)) {
+      // Emit a new var, and update the var remap
+      var_remap_[var->vid] = builder_->Emit(rewrite);
+    }
+  }
+
+  // A use of a remapped var is cast back to the var's original dtype.
+  Expr VisitVar_(const Var& var) {
+    auto it = var_remap_.find(var->vid);
+    if (it != var_remap_.end()) {
+      return RewriteExpr(it->second, NTypeFrom(var));
+    }
+    return var;
+  }
+
+  Expr VisitExpr_(const VarNode* op) final {
+    if (!builder_->CurrentBlockIsDataFlow()) {
+      return ExprMutator::VisitExpr_(op);
+    }
+    return VisitVar_(GetRef<Var>(op));
+  }
+
+  Expr VisitExpr_(const DataflowVarNode* op) final {
+    if (!builder_->CurrentBlockIsDataFlow()) {
+      return ExprMutator::VisitExpr_(op);
+    }
+    return VisitVar_(GetRef<Var>(op));
+  }
+
+  void VisitBinding(const Binding& binding) {
+    ExprMutator::VisitBinding(binding);
+    if (!builder_->CurrentBlockIsDataFlow()) return;
+    CastIfFp16Only(binding->var);
+  }
+
+  void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) final {
+    if (!builder_->CurrentBlockIsDataFlow()) {
+      ExprMutator::VisitBinding_(binding, call_node);
+      return;
+    }
+    auto policy = GetMixedPrecisionInfo(call_node);
+    if (policy == -1) {
+      // not an op call
+      ExprMutator::VisitBinding_(binding, call_node);
+      return;
+    }
+    // var = Call(op)
+    const auto* op_node = call_node->op.as<OpNode>();
+    ICHECK(op_node != nullptr);
+    Op op = GetRef<Op>(op_node);
+    if (wrap_param_op.same_as(op)) {
+      // wrap_param: bind the var directly to the wrapped parameter
+      ReEmitBinding(binding, call_node->args[0]);
+      return;
+    }
+    DataType to;
+    ObjectPtr<CallNode> new_call = make_object<CallNode>(*call_node);
+    // First remap the args to the current vars according to var_remap_.
+    new_call->args = RemapArgs(call_node->args);
+    // Then we rewrite the args according to the policy
+    if (policy == kAlways) {
+      to = fp16_;
+      auto attr_map = Op::GetAttrMap<FInferMixedPrecision>("FInferMixedPrecision");
+      ICHECK(attr_map.count(op));
+      auto f = attr_map[op];
+      new_call = make_object<CallNode>(*(f(Call(new_call), output_dtype_).get()));
+    } else if (policy == kFollow) {
+      to = AnyArgIsFP32(new_call->args) ? fp32_ : fp16_;
+    } else if (policy == kNever) {
+      to = fp32_;
+    } else {
+      LOG(FATAL) << "Unsupported TMixedPrecisionPolicy: " << policy;
+    }
+    new_call->args = RewriteArgs(new_call->args, to);
+    new_call->struct_info_ = NullOpt;
+    Expr new_value = builder_->Normalize(Call(new_call));
+    if (policy == kAlways && binding->var->IsInstance<DataflowVarNode>()) {
+      // kAlways: store the tensors to fp16
+      // But global vars will be stored to the original dtype anyway (see below)
+      new_value = RewriteExpr(new_value, NTypeFrom(new_value, fp16_));
+    }
+    if (!binding->var->IsInstance<DataflowVarNode>()) {
+      // Global var: store the tensors to the original dtype
+      NType orig_type = NTypeFrom(binding->var);
+      new_value = RewriteExpr(new_value, orig_type);
+    }
+    ReEmitBinding(binding, builder_->Normalize(new_value));
+  }
+
+  void VisitBinding_(const VarBindingNode* binding, const TupleNode* tuple_node) final {
+    if (!builder_->CurrentBlockIsDataFlow()) {
+      ExprMutator::VisitBinding_(binding, tuple_node);
+      return;
+    }
+    ObjectPtr<TupleNode> new_tuple = make_object<TupleNode>(*tuple_node);
+    new_tuple->fields = RemapArgs(tuple_node->fields);
+    new_tuple->struct_info_ = NullOpt;
+    Expr new_value = builder_->Normalize(Tuple(new_tuple));
+    if (!binding->var->IsInstance<DataflowVarNode>()) {
+      // Global var: store the tensors to the original dtype
+      NType orig_type = NTypeFrom(binding->var);
+      new_value = RewriteExpr(new_value, orig_type);
+    }
+    ReEmitBinding(binding, builder_->Normalize(new_value));
+  }
+
+  void VisitBinding_(const VarBindingNode* binding,
+                     const TupleGetItemNode* tuple_get_item_node) final {
+    if (!builder_->CurrentBlockIsDataFlow()) {
+      // Outside dataflow blocks the tuple_get_item is left to the default mutator.
+      ExprMutator::VisitBinding_(binding, tuple_get_item_node);
+      return;
+    }
+    ObjectPtr<TupleGetItemNode> new_tuple_get_item =
+        make_object<TupleGetItemNode>(*tuple_get_item_node);
+    new_tuple_get_item->tuple = RemapArgs({tuple_get_item_node->tuple})[0];
+    new_tuple_get_item->struct_info_ = NullOpt;
+    // Normalize first, as the Tuple and Call handlers do: struct_info_ was just cleared, and
+    // RewriteExpr below reads the struct info of new_value (via NTypeFrom/GetStructInfo).
+    Expr new_value = builder_->Normalize(TupleGetItem(new_tuple_get_item));
+    if (!binding->var->IsInstance<DataflowVarNode>()) {
+      // Global var: store the tensors to the original dtype
+      NType orig_type = NTypeFrom(binding->var);
+      new_value = RewriteExpr(new_value, orig_type);
+    }
+    ReEmitBinding(binding, builder_->Normalize(new_value));
+  }
+
+  BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) {
+    builder_->BeginDataflowBlock();
+    // prepare local versions of params here, if they are fp16 expected only
+    for (const Var& param : params_) {
+      CastIfFp16Only(param);
+    }
+    for (auto binding : block->bindings) {
+      this->VisitBinding(binding);
+    }
+    // Remove the block-local versions of params emitted above.
+    for (const Var& param : params_) {
+      var_remap_.erase(param->vid);
+    }
+    return builder_->EndBlock();
+  }
+
+  Expr VisitExpr_(const FunctionNode* op) final {
+    // Remember the params so each dataflow block can prepare its own fp16 copies of them.
+    params_ = op->params;
+    return ExprMutator::VisitExpr_(op);
+  }
+
+  // Per-var requirements computed by DTypeDecisionCollector (not owned).
+  const VarDTypeMap* only_fp16_map_;
+
+  DataType fp16_ = DataType(DataType::TypeCode::kFloat, 16, 1);
+  DataType fp32_ = DataType(DataType::TypeCode::kFloat, 32, 1);
+  // Accumulator dtype passed to each kAlways op's FInferMixedPrecision.
+  DataType output_dtype_;
+  // Params of the function being rewritten.
+  Array<Var> params_;
+
+  const Op& wrap_param_op = Op::Get("relax.wrap_param");
+};
+
+// Runs the two-phase AMP rewrite on one function. The first phase is a backward analysis that
+// records which vars are only consumed by kAlways ops. The second is a forward rewrite that
+// inserts the casts.
+Expr ToMixedPrecision(const Function& f, const DataType& out_dtype) {
+  // Collect returns by value; wrapping the call in std::move would only inhibit copy elision.
+  VarDTypeMap only_fp16_map = DTypeDecisionCollector::Collect(f, out_dtype);
+  ToMixedPrecisionRewriter mutator(&only_fp16_map, out_dtype);
+  return mutator(f);
+}
+
+namespace transform {
+
+// Function-level pass wrapping relax::ToMixedPrecision; out_dtype is forwarded unchanged.
+Pass ToMixedPrecision(const DataType& out_dtype) {
+  // Only the function is needed; the module and pass context are not used.
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> rewrite_func =
+      [=](Function func, IRModule mod, PassContext ctx) {
+        Expr rewritten = ToMixedPrecision(func, out_dtype);
+        return Downcast<Function>(rewritten);
+      };
+  return CreateFunctionPass(/*pass_func=*/rewrite_func, /*opt_level=*/0,
+                            /*name=*/"ToMixedPrecision", /*required=*/{});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.ToMixedPrecision").set_body_typed(ToMixedPrecision);
+
+}  // namespace transform
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/tests/python/relax/test_op_datatype.py 
b/tests/python/relax/test_op_datatype.py
index 56bbe464cf..48820b9e2e 100644
--- a/tests/python/relax/test_op_datatype.py
+++ b/tests/python/relax/test_op_datatype.py
@@ -14,6 +14,9 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import numpy as np  # type: ignore
+
+
 import pytest
 import tvm
 import tvm.testing
@@ -25,7 +28,9 @@ from tvm.script import relax as R
 
 def test_op_correctness():
+    """Check that the datatype op constructors build calls to the expected registered ops."""
     x = relax.Var("x", R.Tensor((2, 3), "float32"))
+    c = relax.Constant(tvm.nd.array(np.array([1, 2, 3], dtype="float16")))
     assert relax.op.astype(x, "float16").op == Op.get("relax.astype")
+    assert relax.op.wrap_param(c, "float32").op == Op.get("relax.wrap_param")
 
 
 def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: 
relax.StructInfo):
@@ -101,5 +106,17 @@ def test_astype_infer_struct_info_wrong_input_type():
         bb.normalize(relax.op.astype(x1, "float16"))
 
 
+def test_wrap_param_infer_struct_info():
+    """wrap_param keeps the constant's shape and takes the requested dtype."""
+    bb = relax.BlockBuilder()
+    for src_dtype, dst_dtype in [("float16", "float32"), ("int8", "int32")]:
+        param = relax.Constant(tvm.nd.array(np.zeros([1, 2, 3], dtype=src_dtype)))
+        expected = relax.TensorStructInfo((1, 2, 3), dst_dtype)
+        _check_inference(bb, relax.op.wrap_param(param, dst_dtype), expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_transform_to_mixed_precision.py 
b/tests/python/relax/test_transform_to_mixed_precision.py
new file mode 100644
index 0000000000..b9409bff52
--- /dev/null
+++ b/tests/python/relax/test_transform_to_mixed_precision.py
@@ -0,0 +1,540 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+import tvm
+from tvm import relax
+import tvm.testing
+from tvm.relax.transform import ToMixedPrecision
+from tvm.script.parser import ir as I, relax as R
+
+
+def _assert_test(input, expected):
+    mod = ToMixedPrecision()(input)
+    tvm.ir.assert_structural_equal(mod, expected)
+
+
+def test_conv2d():
+    """A single conv2d: both inputs are cast to float16, out_dtype stays "float32"."""
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), 
"float32")
+        ) -> R.Tensor(None, "float32", ndim=4):
+            with R.dataflow():
+                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, 
out_dtype="float32")
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 
3, 3), dtype="float32")
+        ) -> R.Tensor((2, 4, 26, 26), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((2, 3, 28, 28), dtype="float16") = R.astype(x, 
dtype="float16")
+                lv1: R.Tensor((4, 3, 3, 3), dtype="float16") = R.astype(w, 
dtype="float16")
+                gv: R.Tensor((2, 4, 26, 26), dtype="float32") = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NCHW",
+                    kernel_layout="OIHW",
+                    out_layout="NCHW",
+                    out_dtype="float32",
+                )
+                R.output(gv)
+            return gv
+
+    _assert_test(Input, Expected)
+
+
+def test_conv2d_relu():
+    """conv2d followed by relu.
+
+    Expected: conv2d consumes float16 casts of x and w with out_dtype="float32";
+    its result is cast to float16 so relu runs in float16, and the final output
+    is cast back to float32.
+    """
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), 
"float32")
+        ) -> R.Tensor(None, "float32", ndim=4):
+            with R.dataflow():
+                lv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, 
out_dtype="float32")
+                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(lv)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 
3, 3), dtype="float32")
+        ) -> R.Tensor((2, 4, 26, 26), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((2, 3, 28, 28), dtype="float16") = R.astype(x, 
dtype="float16")
+                lv1: R.Tensor((4, 3, 3, 3), dtype="float16") = R.astype(w, 
dtype="float16")
+                lv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NCHW",
+                    kernel_layout="OIHW",
+                    out_layout="NCHW",
+                    out_dtype="float32",
+                )
+                lv_1: R.Tensor((2, 4, 26, 26), dtype="float16") = 
R.astype(lv2, dtype="float16")
+                lv3: R.Tensor((2, 4, 26, 26), dtype="float16") = 
R.nn.relu(lv_1)
+                gv: R.Tensor((2, 4, 26, 26), dtype="float32") = R.astype(lv3, 
dtype="float32")
+                R.output(gv)
+            return gv
+
+    _assert_test(Input, Expected)
+
+
+def test_relu_conv2d_relu():
+    """relu before and after a conv2d.
+
+    Expected: the leading relu stays in float32 and its result is cast to
+    float16 for conv2d (out_dtype="float32"); the conv2d result is cast to
+    float16 for the trailing relu, whose output is cast back to float32.
+    """
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), 
"float32")
+        ) -> R.Tensor(None, "float32", ndim=4):
+            with R.dataflow():
+                x0: R.Tensor((2, 3, 28, 28), "float32") = R.nn.relu(x)
+                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x0, w, 
out_dtype="float32")
+                gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv)
+                R.output(gv2)
+            return gv2
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 
3, 3), dtype="float32")
+        ) -> R.Tensor((2, 4, 26, 26), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((4, 3, 3, 3), dtype="float16") = R.astype(w, 
dtype="float16")
+                x0: R.Tensor((2, 3, 28, 28), dtype="float32") = R.nn.relu(x)
+                lv1: R.Tensor((2, 3, 28, 28), dtype="float16") = R.astype(x0, 
dtype="float16")
+                lv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.nn.conv2d(
+                    lv1,
+                    lv,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NCHW",
+                    kernel_layout="OIHW",
+                    out_layout="NCHW",
+                    out_dtype="float32",
+                )
+                gv: R.Tensor((2, 4, 26, 26), dtype="float16") = R.astype(lv2, 
dtype="float16")
+                lv3: R.Tensor((2, 4, 26, 26), dtype="float16") = R.nn.relu(gv)
+                gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.astype(lv3, 
dtype="float32")
+                R.output(gv2)
+            return gv2
+
+    _assert_test(Input, Expected)
+
+
+def test_conv2d_relu_conv2d():
+    """conv2d -> relu -> conv2d.
+
+    Expected: x, w and w2 are cast to float16 up front; the first conv2d result
+    is cast to float16 so relu runs in float16, and the second conv2d consumes
+    that float16 value directly while keeping out_dtype="float32".
+    """
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), "float32"),
+            w: R.Tensor((4, 3, 3, 3), "float32"),
+            w2: R.Tensor((4, 4, 3, 3), "float32"),
+        ) -> R.Tensor(None, "float32", ndim=4):
+            with R.dataflow():
+                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, 
out_dtype="float32")
+                gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv)
+                gv3: R.Tensor((2, 4, 24, 24), "float32") = R.nn.conv2d(gv2, 
w2, out_dtype="float32")
+                R.output(gv3)
+            return gv3
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"),
+            w: R.Tensor((4, 3, 3, 3), dtype="float32"),
+            w2: R.Tensor((4, 4, 3, 3), dtype="float32"),
+        ) -> R.Tensor((2, 4, 24, 24), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((2, 3, 28, 28), dtype="float16") = R.astype(x, 
dtype="float16")
+                lv1: R.Tensor((4, 3, 3, 3), dtype="float16") = R.astype(w, 
dtype="float16")
+                lv2: R.Tensor((4, 4, 3, 3), dtype="float16") = R.astype(w2, 
dtype="float16")
+                lv3: R.Tensor((2, 4, 26, 26), dtype="float32") = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NCHW",
+                    kernel_layout="OIHW",
+                    out_layout="NCHW",
+                    out_dtype="float32",
+                )
+                gv: R.Tensor((2, 4, 26, 26), dtype="float16") = R.astype(lv3, 
dtype="float16")
+                gv2: R.Tensor((2, 4, 26, 26), dtype="float16") = R.nn.relu(gv)
+                gv3: R.Tensor((2, 4, 24, 24), dtype="float32") = R.nn.conv2d(
+                    gv2,
+                    lv2,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NCHW",
+                    kernel_layout="OIHW",
+                    out_layout="NCHW",
+                    out_dtype="float32",
+                )
+                R.output(gv3)
+            return gv3
+
+    _assert_test(Input, Expected)
+
+
+def test_gemm_add_silu():
+    """matmul followed by add and silu.
+
+    Expected: matmul runs on float16 casts of x and w1 with out_dtype="float32";
+    its result is cast to float16 (gv0) and then back to float32 for the add
+    with w2, and add/silu stay in float32.
+    """
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 320), "float32"),
+            w1: R.Tensor((320, 1280), "float32"),
+            w2: R.Tensor((2, 1280), "float32"),
+        ) -> R.Tensor(None, "float32", ndim=2):
+            with R.dataflow():
+                gv0: R.Tensor((2, 1280), "float32") = R.matmul(x, w1, 
out_dtype="float32")
+                gv1: R.Tensor((2, 1280), "float32") = R.add(gv0, w2)
+                gv2: R.Tensor((2, 1280), "float32") = R.nn.silu(gv1)
+                R.output(gv2)
+            return gv2
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 320), dtype="float32"),
+            w1: R.Tensor((320, 1280), dtype="float32"),
+            w2: R.Tensor((2, 1280), dtype="float32"),
+        ) -> R.Tensor((2, 1280), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((2, 320), dtype="float16") = R.astype(x, 
dtype="float16")
+                lv1: R.Tensor((320, 1280), dtype="float16") = R.astype(w1, 
dtype="float16")
+                lv2: R.Tensor((2, 1280), dtype="float32") = R.matmul(lv, lv1, 
out_dtype="float32")
+                gv0: R.Tensor((2, 1280), dtype="float16") = R.astype(lv2, 
dtype="float16")
+                lv3: R.Tensor((2, 1280), dtype="float32") = R.astype(gv0, 
dtype="float32")
+                gv1: R.Tensor((2, 1280), dtype="float32") = R.add(lv3, w2)
+                gv2: R.Tensor((2, 1280), dtype="float32") = R.nn.silu(gv1)
+                R.output(gv2)
+            return gv2
+
+    _assert_test(Input, Expected)
+
+
+def test_tuple():
+    """Tuples and tuple indexing around conv2d results.
+
+    Expected: both conv2d results are cast to float16 before being packed into
+    (nested) tuples, so the tuple struct infos hold float16 tensors, and the
+    element extracted via gv4[0][0] feeds the final conv2d without an extra
+    cast.
+    """
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), "float32"),
+            w: R.Tensor((4, 3, 3, 3), "float32"),
+            w_2: R.Tensor((4, 4, 3, 3), "float32"),
+        ) -> R.Tensor(None, "float32", ndim=4):
+            with R.dataflow():
+                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, 
out_dtype="float32")
+                gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, 
out_dtype="float32")
+                gv3 = (gv, gv2)
+                gv4 = (gv3, gv2)
+                gv5 = gv4[0]
+                gv6 = gv5[0]
+                gv7 = R.nn.conv2d(gv6, w_2, out_dtype="float32")
+                R.output(gv7)
+            return gv7
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"),
+            w: R.Tensor((4, 3, 3, 3), dtype="float32"),
+            w_2: R.Tensor((4, 4, 3, 3), dtype="float32"),
+        ) -> R.Tensor((2, 4, 24, 24), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((2, 3, 28, 28), dtype="float16") = R.astype(x, 
dtype="float16")
+                lv1: R.Tensor((4, 3, 3, 3), dtype="float16") = R.astype(w, 
dtype="float16")
+                lv2: R.Tensor((4, 4, 3, 3), dtype="float16") = R.astype(w_2, 
dtype="float16")
+                lv3: R.Tensor((2, 4, 26, 26), dtype="float32") = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NCHW",
+                    kernel_layout="OIHW",
+                    out_layout="NCHW",
+                    out_dtype="float32",
+                )
+                gv: R.Tensor((2, 4, 26, 26), dtype="float16") = R.astype(lv3, 
dtype="float16")
+                lv4: R.Tensor((2, 4, 26, 26), dtype="float32") = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NCHW",
+                    kernel_layout="OIHW",
+                    out_layout="NCHW",
+                    out_dtype="float32",
+                )
+                gv2: R.Tensor((2, 4, 26, 26), dtype="float16") = R.astype(lv4, 
dtype="float16")
+                gv3: R.Tuple(
+                    R.Tensor((2, 4, 26, 26), dtype="float16"),
+                    R.Tensor((2, 4, 26, 26), dtype="float16"),
+                ) = (gv, gv2)
+                gv4: R.Tuple(
+                    R.Tuple(
+                        R.Tensor((2, 4, 26, 26), dtype="float16"),
+                        R.Tensor((2, 4, 26, 26), dtype="float16"),
+                    ),
+                    R.Tensor((2, 4, 26, 26), dtype="float16"),
+                ) = (gv3, gv2)
+                gv5: R.Tuple(
+                    R.Tensor((2, 4, 26, 26), dtype="float16"),
+                    R.Tensor((2, 4, 26, 26), dtype="float16"),
+                ) = gv4[0]
+                gv6: R.Tensor((2, 4, 26, 26), dtype="float16") = gv5[0]
+                gv7: R.Tensor((2, 4, 24, 24), dtype="float32") = R.nn.conv2d(
+                    gv6,
+                    lv2,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NCHW",
+                    kernel_layout="OIHW",
+                    out_layout="NCHW",
+                    out_dtype="float32",
+                )
+                R.output(gv7)
+            return gv7
+
+    _assert_test(Input, Expected)
+
+
+def test_concat_matmul():
+    """concat followed by matmul.
+
+    Expected: concat stays in float32; its result and the weight w are cast to
+    float16 for a matmul that keeps out_dtype="float32".
+    """
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            lv10: R.Tensor((2, 160), "float32"),
+            lv12: R.Tensor((2, 160), "float32"),
+            w: R.Tensor((320, 1280), "float32"),
+        ) -> R.Tensor(None, "float32", ndim=2):
+            with R.dataflow():
+                lv13: R.Tensor((2, 320), "float32") = R.concat((lv10, lv12), 
axis=-1)
+                lv14: R.Tensor((2, 1280), "float32") = R.matmul(lv13, w, 
out_dtype="float32")
+                R.output(lv14)
+            return lv14
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            lv10: R.Tensor((2, 160), dtype="float32"),
+            lv12: R.Tensor((2, 160), dtype="float32"),
+            w: R.Tensor((320, 1280), dtype="float32"),
+        ) -> R.Tensor((2, 1280), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((320, 1280), dtype="float16") = R.astype(w, 
dtype="float16")
+                lv13: R.Tensor((2, 320), dtype="float32") = R.concat((lv10, 
lv12), axis=-1)
+                lv1: R.Tensor((2, 320), dtype="float16") = R.astype(lv13, 
dtype="float16")
+                lv14: R.Tensor((2, 1280), dtype="float32") = R.matmul(lv1, lv, 
out_dtype="float32")
+                R.output(lv14)
+            return lv14
+
+    _assert_test(Input, Expected)
+
+
+def test_conv2d_softmax():
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((3, 3, 3, 3), 
"float32")
+        ) -> R.Tensor(None, "float32", ndim=4):
+            with R.dataflow():
+                gv: R.Tensor((2, 3, 26, 26), "float32") = R.nn.conv2d(x, w, 
padding=(1, 1))
+                gv1: R.Tensor((2, 3, 26, 26), "float32") = R.nn.softmax(x, 
axis=1)
+                gv2 = R.add(gv, gv1)
+                R.output(gv2)
+            return gv2
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((3, 3, 
3, 3), dtype="float32")
+        ) -> R.Tensor((2, 3, 26, 26), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((3, 3, 3, 3), dtype="float16") = R.astype(w, 
dtype="float16")
+                lv1: R.Tensor((2, 3, 28, 28), dtype="float16") = R.astype(x, 
dtype="float16")
+                lv2: R.Tensor((2, 3, 28, 28), dtype="float32") = R.nn.conv2d(
+                    lv1,
+                    lv,
+                    strides=[1, 1],
+                    padding=[1, 1, 1, 1],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NCHW",
+                    kernel_layout="OIHW",
+                    out_layout="NCHW",
+                    out_dtype="float32",
+                )
+                gv: R.Tensor((2, 3, 28, 28), dtype="float16") = R.astype(lv2, 
dtype="float16")
+                gv1: R.Tensor((2, 3, 28, 28), dtype="float32") = 
R.nn.softmax(x, axis=1)
+                lv3: R.Tensor((2, 3, 28, 28), dtype="float32") = R.astype(gv, 
dtype="float32")
+                gv2: R.Tensor((2, 3, 28, 28), dtype="float32") = R.add(lv3, 
gv1)
+                R.output(gv2)
+            return gv2
+
+    _assert_test(Input, Expected)
+
+
+def test_conv2d_bias_conv2d():
+    @tvm.script.ir_module
+    class Input:
+        @R.function
+        def main(
+            z: R.Tensor((1, 4, 64, 64), dtype="float32"),
+            w0: R.Tensor((512, 4, 3, 3), dtype="float16"),
+            w1: R.Tensor((512,), dtype="float16"),
+            w2: R.Tensor((4, 4, 1, 1), dtype="float16"),
+            w3: R.Tensor((4,), dtype="float16"),
+        ) -> R.Tensor((1, 512, 64, 64), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((512, 4, 3, 3), dtype="float32") = 
R.wrap_param(w0, dtype="float32")
+                lv1: R.Tensor((512,), dtype="float32") = R.wrap_param(w1, 
dtype="float32")
+                lv140: R.Tensor((4, 4, 1, 1), dtype="float32") = 
R.wrap_param(w2, dtype="float32")
+                lv141: R.Tensor((4,), dtype="float32") = R.wrap_param(w3, 
dtype="float32")
+                lv142: R.Tensor((1, 4, 64, 64), dtype="float32") = R.nn.conv2d(
+                    z,
+                    lv140,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NCHW",
+                    kernel_layout="OIHW",
+                    out_layout="NCHW",
+                    out_dtype="float32",
+                )
+                lv143: R.Tensor((1, 4, 1, 1), dtype="float32") = 
R.reshape(lv141, (1, 4, 1, 1))
+                lv144: R.Tensor((1, 4, 64, 64), dtype="float32") = 
R.add(lv142, lv143)
+                lv145: R.Tensor((1, 512, 64, 64), dtype="float32") = 
R.nn.conv2d(
+                    lv144,
+                    lv,
+                    strides=[1, 1],
+                    padding=[1, 1, 1, 1],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NCHW",
+                    kernel_layout="OIHW",
+                    out_layout="NCHW",
+                    out_dtype="float32",
+                )
+                lv146: R.Tensor((1, 512, 1, 1), dtype="float32") = 
R.reshape(lv1, (1, 512, 1, 1))
+                lv147: R.Tensor((1, 512, 64, 64), dtype="float32") = 
R.add(lv145, lv146)
+                gv: R.Tensor((1, 512, 64, 64), dtype="float32") = lv147
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            z: R.Tensor((1, 4, 64, 64), dtype="float32"),
+            w0: R.Tensor((512, 4, 3, 3), dtype="float16"),
+            w1: R.Tensor((512,), dtype="float16"),
+            w2: R.Tensor((4, 4, 1, 1), dtype="float16"),
+            w3: R.Tensor((4,), dtype="float16"),
+        ) -> R.Tensor((1, 512, 64, 64), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 4, 64, 64), dtype="float16") = R.astype(z, 
dtype="float16")
+                lv_1: R.Tensor((512, 4, 3, 3), dtype="float16") = w0
+                lv1: R.Tensor((512,), dtype="float16") = w1
+                lv140: R.Tensor((4, 4, 1, 1), dtype="float16") = w2
+                lv141: R.Tensor((4,), dtype="float16") = w3
+                lv1_1: R.Tensor((1, 4, 64, 64), dtype="float32") = R.nn.conv2d(
+                    lv,
+                    lv140,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NCHW",
+                    kernel_layout="OIHW",
+                    out_layout="NCHW",
+                    out_dtype="float32",
+                )
+                lv142: R.Tensor((1, 4, 64, 64), dtype="float16") = 
R.astype(lv1_1, dtype="float16")
+                lv143: R.Tensor((1, 4, 1, 1), dtype="float16") = 
R.reshape(lv141, (1, 4, 1, 1))
+                lv144: R.Tensor((1, 4, 64, 64), dtype="float16") = 
R.add(lv142, lv143)
+                lv2: R.Tensor((1, 512, 64, 64), dtype="float32") = R.nn.conv2d(
+                    lv144,
+                    lv_1,
+                    strides=[1, 1],
+                    padding=[1, 1, 1, 1],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NCHW",
+                    kernel_layout="OIHW",
+                    out_layout="NCHW",
+                    out_dtype="float32",
+                )
+                lv145: R.Tensor((1, 512, 64, 64), dtype="float16") = 
R.astype(lv2, dtype="float16")
+                lv146: R.Tensor((1, 512, 1, 1), dtype="float16") = 
R.reshape(lv1, (1, 512, 1, 1))
+                lv147: R.Tensor((1, 512, 64, 64), dtype="float16") = 
R.add(lv145, lv146)
+                gv: R.Tensor((1, 512, 64, 64), dtype="float32") = 
R.astype(lv147, dtype="float32")
+                R.output(gv)
+            return gv
+
+    binding = {
+        "w0": np.random.uniform(size=(512, 4, 3, 3)).astype("float16"),
+        "w1": np.random.uniform(size=(512,)).astype("float16"),
+        "w2": np.random.uniform(size=(4, 4, 1, 1)).astype("float16"),
+        "w3": np.random.uniform(size=(4,)).astype("float16"),
+    }
+    binding = {k: tvm.nd.array(v) for k, v in binding.items()}
+    Input = relax.transform.BindParams("main", binding)(Input)
+    Expected = relax.transform.BindParams("main", binding)(Expected)
+    _assert_test(Input, Expected)
+
+
+# Entry point when this file is executed directly.
+if __name__ == "__main__":
+    tvm.testing.main()

Reply via email to