This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 5e2e942d5a [Unity][Transform] Automatic Mixed Precision (#14242)
5e2e942d5a is described below
commit 5e2e942d5a584f07f30c285ba3391bc51a20890a
Author: Bohan Hou <[email protected]>
AuthorDate: Sun Mar 19 12:05:42 2023 -0400
[Unity][Transform] Automatic Mixed Precision (#14242)
This PR adds a new pass ToMixedPrecision to automatically cast fp32 models
to fp16 when necessary.
See
https://github.com/spectrometerHBH/tvm/blob/amp/src/relax/transform/to_mixed_precision.cc#L51
for details on how this pass works.
---
include/tvm/relax/attrs/datatype.h | 9 +
include/tvm/relax/transform.h | 9 +-
python/tvm/relax/op/datatype.py | 17 +
python/tvm/relax/transform/transform.py | 46 +-
python/tvm/script/ir_builder/relax/ir.py | 2 +
src/relax/op/image/resize.cc | 3 +-
src/relax/op/nn/convolution.cc | 12 +-
src/relax/op/nn/nn.cc | 9 +-
src/relax/op/nn/pooling.cc | 9 +-
src/relax/op/op_common.h | 12 +-
src/relax/op/tensor/binary.h | 23 +-
src/relax/op/tensor/create.cc | 12 +-
src/relax/op/tensor/datatype.cc | 30 +-
src/relax/op/tensor/datatype.h | 8 +
src/relax/op/tensor/index.cc | 3 +-
src/relax/op/tensor/linear_algebra.cc | 8 +-
src/relax/op/tensor/manipulate.cc | 27 +-
src/relax/op/tensor/ternary.cc | 3 +-
src/relax/transform/infer_amp_utils.cc | 59 +++
src/relax/transform/infer_amp_utils.h | 85 ++++
src/relax/transform/to_mixed_precision.cc | 538 ++++++++++++++++++++
tests/python/relax/test_op_datatype.py | 17 +
.../relax/test_transform_to_mixed_precision.py | 540 +++++++++++++++++++++
23 files changed, 1408 insertions(+), 73 deletions(-)
diff --git a/include/tvm/relax/attrs/datatype.h
b/include/tvm/relax/attrs/datatype.h
index 79cb345688..c5a5a4e7d2 100644
--- a/include/tvm/relax/attrs/datatype.h
+++ b/include/tvm/relax/attrs/datatype.h
@@ -38,6 +38,15 @@ struct AstypeAttrs : public tvm::AttrsNode<AstypeAttrs> {
}
}; // struct AstypeAttrs.
+/*! \brief Attributes used in wrap_param operator */
+struct WrapParamAttrs : public tvm::AttrsNode<WrapParamAttrs> {
+ DataType dtype;
+
+ TVM_DECLARE_ATTRS(WrapParamAttrs, "relax.attrs.WrapParamAttrs") {
+ TVM_ATTR_FIELD(dtype).describe("Target data type");
+ }
+}; // struct WrapParamAttrs.
+
} // namespace relax
} // namespace tvm
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 369290e661..eba7de1b0c 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -385,7 +385,6 @@ TVM_DLL Pass AlterOpImpl(const Map<String, tir::PrimFunc>&
op_impl_map,
*/
TVM_DLL Pass ConvertLayout(Map<String, Array<String>> desired_layouts);
-
/*!
* \brief Dead code elimination.
* Currently it removes:
@@ -401,6 +400,14 @@ TVM_DLL Pass ConvertLayout(Map<String, Array<String>>
desired_layouts);
*/
TVM_DLL Pass DeadCodeElimination(Array<runtime::String> entry_functions);
+/*!
+ * \brief Automatic mixed precision pass. Currently the pass assumes the input
module to be fp32
+ * only, and will automatically cast fp32 to fp16 for certain ops.
+ * \param out_dtype The output data type of gemm/conv, which is the data type
of the accumulator.
+ * \return The Pass.
+ */
+TVM_DLL Pass ToMixedPrecision(const DataType& out_dtype);
+
} // namespace transform
} // namespace relax
} // namespace tvm
diff --git a/python/tvm/relax/op/datatype.py b/python/tvm/relax/op/datatype.py
index 5c02776dd7..120487c0bd 100644
--- a/python/tvm/relax/op/datatype.py
+++ b/python/tvm/relax/op/datatype.py
@@ -40,3 +40,20 @@ def astype(x: Expr, dtype: Union[str, DataType]) -> Expr:
The casted result.
"""
return _ffi_api.astype(x, dtype) # type: ignore
+
+
+def wrap_param(data: Expr, dtype: Union[str, DataType] = "float32") -> Expr:
+ """Cast input tensor which is model param to data type if the dtype of the
input data is not
+ the same as the given dtype.
+ Parameters
+ ----------
+ data : relax.Expr
+ The input data to the operator.
+ dtype : Union[str, DataType]
+ The target data type
+ Returns
+ -------
+ result : relax.Expr
+ The casted result.
+ """
+ return _ffi_api.wrap_param(data, dtype) # type: ignore
diff --git a/python/tvm/relax/transform/transform.py
b/python/tvm/relax/transform/transform.py
index 72768bf676..ebfd7a6765 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -657,37 +657,6 @@ def ConvertLayout(desired_layouts: Dict[str, List[str]])
-> tvm.ir.transform.Pas
return _ffi_api.ConvertLayout(desired_layouts) # type: ignore
-def AlterOpImpl(
- op_impl_map: Dict[str, PrimFunc],
- op_buffer_transforms: Dict[str, List[Union[IndexMap, Callable]]],
-):
- """Replace all PrimFunc's which have matching 'operator_name' attribute,
with replacement
- PrimFunc that could possibly have different layouts on i/o buffers. The
layout
- transformations on i/o buffers is present in the op_buffer_transforms map.
Inserts the layout
- transformations in the call sites of PrimFuncs being replaced to transform
i/o
- tensors into expected layout by new PrimFunc.
-
- Parameters
- ----------
- op_impl_map: Dict[str, PrimFunc]
- op_kind to PrimFunc map
- op_buffer_transforms: Dict[str, List[Union[IndexMap, Callable]]
- op_kind to layout transformation map for each of the buffers
- Returns
- -------
- ret: tvm.ir.transform.Pass
- """
- for operator_name, transform_list in op_buffer_transforms.items():
- l = []
- for transform in transform_list:
- if isinstance(transform, Callable):
- transform = IndexMap.from_func(transform)
- l.append(transform)
- op_buffer_transforms[operator_name] = l
-
- return _ffi_api.AlterOpImpl(op_impl_map, op_buffer_transforms) # type:
ignore
-
-
def DeadCodeElimination(entry_functions: Optional[List[str]] = None) ->
tvm.ir.transform.Pass:
"""Remove dead code in the program.
Currently it removes:
@@ -715,6 +684,21 @@ def DeadCodeElimination(entry_functions:
Optional[List[str]] = None) -> tvm.ir.t
return _ffi_api.DeadCodeElimination(entry_functions) # type: ignore
+def ToMixedPrecision(out_dtype="float32") -> tvm.ir.transform.Pass:
+ """Automatic mixed precision pass. Currently the pass assumes the input
module to be fp32
+ only, and will automatically cast fp32 to fp16 for certain ops.
+ Parameters
+ ----------
+ out_dtype : str
+ The output data type of gemm/conv, which is the data type of the
accumulator.
+ Returns
+ -------
+ ret : tvm.transform.Pass
+ The registered pass for mixed precision.
+ """
+ return _ffi_api.ToMixedPrecision(out_dtype) # type: ignore
+
+
def _wrap_class_function_pass(pass_cls, pass_info):
"""Wrap a python class as function pass."""
diff --git a/python/tvm/script/ir_builder/relax/ir.py
b/python/tvm/script/ir_builder/relax/ir.py
index b3190ea334..32d6083e8a 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -121,6 +121,7 @@ from tvm.relax.op import (
unique,
vm,
where,
+ wrap_param,
zeros,
zeros_like,
nn,
@@ -641,6 +642,7 @@ __all__ = [
"variance",
"vm",
"where",
+ "wrap_param",
"zeros",
"zeros_like",
"nn",
diff --git a/src/relax/op/image/resize.cc b/src/relax/op/image/resize.cc
index de6eec6236..6d49bea6b6 100644
--- a/src/relax/op/image/resize.cc
+++ b/src/relax/op/image/resize.cc
@@ -121,7 +121,8 @@ TVM_REGISTER_OP("relax.image.resize2d")
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("size", "Shape", "The output image shape.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoResize2D)
- .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutResize2d);
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutResize2d)
+ .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kFollow);
} // namespace relax
} // namespace tvm
diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc
index f356876620..e10d205b23 100644
--- a/src/relax/op/nn/convolution.cc
+++ b/src/relax/op/nn/convolution.cc
@@ -177,13 +177,23 @@ InferLayoutOutput InferLayoutConv2d(const Call& call,
return InferLayoutOutput({data_layout, weight_layout}, {output_layout},
Attrs(new_attrs));
}
+Call InferMixedPrecisionConv2d(const Call& call, const DataType& out_dtype) {
+ const auto* conv2d_attrs = call->attrs.as<Conv2DAttrs>();
+ return Downcast<Call>(conv2d(call->args[0], call->args[1],
conv2d_attrs->strides,
+ conv2d_attrs->padding, conv2d_attrs->dilation,
conv2d_attrs->groups,
+ conv2d_attrs->data_layout,
conv2d_attrs->kernel_layout,
+ conv2d_attrs->out_layout, out_dtype));
+}
+
TVM_REGISTER_OP("relax.nn.conv2d")
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.")
.set_attrs_type<Conv2DAttrs>()
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoConv2d)
- .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutConv2d);
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutConv2d)
+ .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kAlways)
+ .set_attr<FInferMixedPrecision>("FInferMixedPrecision",
InferMixedPrecisionConv2d);
/* relax.nn.conv2d_transpose */
TVM_REGISTER_NODE_TYPE(Conv2DTransposeAttrs);
diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc
index 6bce51ca50..c3e18f8e3b 100644
--- a/src/relax/op/nn/nn.cc
+++ b/src/relax/op/nn/nn.cc
@@ -294,7 +294,8 @@ TVM_REGISTER_OP("relax.nn.layer_norm")
.add_argument("gamma", "Tensor", "The gamma scale factor.")
.add_argument("beta", "Tensor", "The beta offset factor.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoLayerNorm)
- .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutLayerNorm);
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutLayerNorm)
+ .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kFollow);
/* relax.nn.group_norm */
TVM_REGISTER_NODE_TYPE(GroupNormAttrs);
@@ -404,7 +405,8 @@ TVM_REGISTER_OP("relax.nn.group_norm")
.add_argument("gamma", "Tensor", "The gamma scale factor.")
.add_argument("beta", "Tensor", "The beta offset factor.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoGroupNorm)
- .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutGroupNorm);
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutGroupNorm)
+ .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kFollow);
/* relax.nn.dropout */
TVM_REGISTER_NODE_TYPE(DropoutAttrs);
@@ -429,7 +431,8 @@ TVM_REGISTER_OP("relax.nn.dropout")
.set_num_inputs(1)
.add_argument("data", "Tensor", "Input to which dropout will be applied.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoDropout)
- .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutUnaryEwise);
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutUnaryEwise)
+ .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kFollow);
/* relax.nn.cross_entropy_with_logits */
StructInfo InferStructInfoCrossEntropy(const Call& call, const BlockBuilder&
ctx) {
diff --git a/src/relax/op/nn/pooling.cc b/src/relax/op/nn/pooling.cc
index be0a794dee..c31ce3dd0b 100644
--- a/src/relax/op/nn/pooling.cc
+++ b/src/relax/op/nn/pooling.cc
@@ -139,7 +139,8 @@ TVM_REGISTER_OP("relax.nn.max_pool2d")
.add_argument("data", "Tensor", "The input tensor")
.set_attrs_type<Pool2DAttrs>()
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoPool2D)
- .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutPool2d);
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutPool2d)
+ .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kFollow);
Expr avg_pool2d(Expr data, Array<IntImm> pool_size, Array<IntImm> strides,
Array<IntImm> padding,
Array<IntImm> dilation, bool ceil_mode, String layout,
@@ -155,7 +156,8 @@ TVM_REGISTER_OP("relax.nn.avg_pool2d")
.add_argument("data", "Tensor", "The input tensor")
.set_attrs_type<Pool2DAttrs>()
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoPool2D)
- .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutPool2d);
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutPool2d)
+ .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kFollow);
/* relax.nn.adaptive_avg_pool2d */
TVM_REGISTER_NODE_TYPE(AdaptivePool2DAttrs);
@@ -237,7 +239,8 @@ TVM_REGISTER_OP("relax.nn.adaptive_avg_pool2d")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor")
.set_attr<FInferStructInfo>("FInferStructInfo",
InferStructInfoAdaptiveAvgPool2D)
- .set_attr<FRelaxInferLayout>("FRelaxInferLayout",
InferLayoutAdaptiveAvgPool2D);
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout",
InferLayoutAdaptiveAvgPool2D)
+ .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kFollow);
} // namespace relax
} // namespace tvm
diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h
index ece4c4a321..bd5f2cd4d5 100644
--- a/src/relax/op/op_common.h
+++ b/src/relax/op/op_common.h
@@ -34,6 +34,7 @@
#include <utility>
#include <vector>
+#include "../transform/infer_amp_utils.h"
#include "../transform/infer_layout_utils.h"
namespace tvm {
@@ -70,11 +71,12 @@ inline TensorStructInfo GetUnaryInputTensorStructInfo(const
Call& call, const Bl
* \param OpRegName The name of operator to register. The name passed in will
* be prepended with a prefix "relax." as the identifier string in the
operator registry.
*/
-#define RELAX_REGISTER_UNARY_OP(OpRegName) \
- TVM_REGISTER_OP("relax." OpRegName) \
- .set_num_inputs(1) \
- .add_argument("x", "Tensor", "The input tensor.") \
- .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutUnaryEwise)
+#define RELAX_REGISTER_UNARY_OP(OpRegName)
\
+ TVM_REGISTER_OP("relax." OpRegName)
\
+ .set_num_inputs(1)
\
+ .add_argument("x", "Tensor", "The input tensor.")
\
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutUnaryEwise)
\
+ .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kFollow)
/*!
* \brief Quick helper macro to expose a make-function to construct the
operator.
diff --git a/src/relax/op/tensor/binary.h b/src/relax/op/tensor/binary.h
index 197110c000..086e37f883 100644
--- a/src/relax/op/tensor/binary.h
+++ b/src/relax/op/tensor/binary.h
@@ -37,17 +37,18 @@ namespace relax {
* 1. be prepended with a prefix "relax.op." as the FFI identifier string for
the make function,
* 2. be prepended with a prefix "relax." as the identifier string in the
operator registry.
*/
-#define RELAX_REGISTER_BINARY_OP_AND_IMPL(OpName) \
- Expr OpName(Expr x1, Expr x2) { \
- static const Op& op = Op::Get("relax." #OpName); \
- return Call(op, {x1, x2}, Attrs(), {}); \
- } \
- TVM_REGISTER_GLOBAL("relax.op." #OpName).set_body_typed(OpName); \
- TVM_REGISTER_OP("relax." #OpName) \
- .set_num_inputs(2) \
- .add_argument("x1", "Tensor", "The first input tensor.") \
- .add_argument("x2", "Tensor", "The second input tensor.") \
- .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutBinaryEwise)
+#define RELAX_REGISTER_BINARY_OP_AND_IMPL(OpName)
\
+ Expr OpName(Expr x1, Expr x2) {
\
+ static const Op& op = Op::Get("relax." #OpName);
\
+ return Call(op, {x1, x2}, Attrs(), {});
\
+ }
\
+ TVM_REGISTER_GLOBAL("relax.op." #OpName).set_body_typed(OpName);
\
+ TVM_REGISTER_OP("relax." #OpName)
\
+ .set_num_inputs(2)
\
+ .add_argument("x1", "Tensor", "The first input tensor.")
\
+ .add_argument("x2", "Tensor", "The second input tensor.")
\
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout",
InferLayoutBinaryEwise) \
+ .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kFollow)
#define RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(OpName) \
RELAX_REGISTER_BINARY_OP_AND_IMPL(OpName).set_attr<FInferStructInfo>( \
diff --git a/src/relax/op/tensor/create.cc b/src/relax/op/tensor/create.cc
index e8374d1981..d4e5e166b7 100644
--- a/src/relax/op/tensor/create.cc
+++ b/src/relax/op/tensor/create.cc
@@ -82,7 +82,8 @@ TVM_REGISTER_OP("relax.full")
.set_num_inputs(2)
.add_argument("shape", "Shape", "The shape of the created tensor.")
.add_argument("fill_value", "Tensor", "The scalar tensor, denoting the
value to fill.")
- .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoFull);
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoFull)
+ .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kFollow);
/* relax.full_like */
Expr full_like(Expr x, Expr fill_value, DataType dtype) {
@@ -119,7 +120,8 @@ TVM_REGISTER_OP("relax.full_like")
.set_num_inputs(2)
.add_argument("x", "Tensor", "The input tensor.")
.add_argument("fill_value", "Tensor", "The scalar value to fill.")
- .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoFullLike);
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoFullLike)
+ .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kFollow);
// Structure info inference for ones and zeros
StructInfo InferStructInfoOnesZeros(const Call& call, const BlockBuilder& ctx)
{
@@ -175,7 +177,8 @@ TVM_REGISTER_OP("relax.ones")
.set_attrs_type<InitAttrs>()
.set_num_inputs(1)
.add_argument("shape", "Shape", "The shape of the created tensor.")
- .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoOnesZeros);
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoOnesZeros)
+ .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kFollow);
TVM_REGISTER_OP("relax.ones_like")
.set_attrs_type<InitAttrs>()
@@ -207,7 +210,8 @@ TVM_REGISTER_OP("relax.zeros")
.set_attrs_type<InitAttrs>()
.set_num_inputs(1)
.add_argument("shape", "Shape", "The shape of the created tensor.")
- .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoOnesZeros);
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoOnesZeros)
+ .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kFollow);
TVM_REGISTER_OP("relax.zeros_like")
.set_attrs_type<InitAttrs>()
diff --git a/src/relax/op/tensor/datatype.cc b/src/relax/op/tensor/datatype.cc
index 349a54ee4d..18747fedcd 100644
--- a/src/relax/op/tensor/datatype.cc
+++ b/src/relax/op/tensor/datatype.cc
@@ -55,7 +55,35 @@ TVM_REGISTER_OP("relax.astype")
.set_num_inputs(1)
.add_argument("x", "Tensor", "The input tensor")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoAstype)
- .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutUnaryEwise);
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutUnaryEwise)
+ .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kFollow);
+
+/* relax.wrap_param */
+TVM_REGISTER_NODE_TYPE(WrapParamAttrs);
+
+Expr MakeWrapParam(Expr data, DataType dtype) {
+ ObjectPtr<WrapParamAttrs> attrs = make_object<WrapParamAttrs>();
+ attrs->dtype = dtype;
+
+ static const Op& op = Op::Get("relax.wrap_param");
+ return Call(op, {std::move(data)}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.wrap_param").set_body_typed(MakeWrapParam);
+
+StructInfo InferStructInfoWrapParam(const Call& call, const BlockBuilder& ctx)
{
+ TensorStructInfo sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+ const auto* attrs = call->attrs.as<WrapParamAttrs>();
+ ObjectPtr<TensorStructInfoNode> new_sinfo =
make_object<TensorStructInfoNode>(*sinfo.get());
+ new_sinfo->dtype = attrs->dtype;
+ return TensorStructInfo(new_sinfo);
+}
+
+TVM_REGISTER_OP("relax.wrap_param")
+ .set_attrs_type<WrapParamAttrs>()
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "The input tensor")
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoWrapParam);
} // namespace relax
} // namespace tvm
diff --git a/src/relax/op/tensor/datatype.h b/src/relax/op/tensor/datatype.h
index 6afa7a50d4..b612c45fc9 100644
--- a/src/relax/op/tensor/datatype.h
+++ b/src/relax/op/tensor/datatype.h
@@ -39,6 +39,14 @@ namespace relax {
*/
Expr astype(Expr x, DataType dtype);
+/*!
+ * \brief A wrapper to wrap the input const tensor to the given data type.
+ * \param x The input const tensor to the operator.
+ * \param dtype The target data type
+ * \return The wrapped result.
+ */
+Expr wrap_param(Expr x, DataType dtype);
+
} // namespace relax
} // namespace tvm
diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc
index 29f668ccf3..218de6e2c6 100644
--- a/src/relax/op/tensor/index.cc
+++ b/src/relax/op/tensor/index.cc
@@ -210,7 +210,8 @@ TVM_REGISTER_OP("relax.strided_slice")
.set_num_inputs(1)
.add_argument("x", "Tensor", "The source tensor to be sliced.")
.set_attr<FInferStructInfo>("FInferStructInfo",
InferStructInfoStridedSlice)
- .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutStridedSlice);
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutStridedSlice)
+ .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kFollow);
} // namespace relax
} // namespace tvm
diff --git a/src/relax/op/tensor/linear_algebra.cc
b/src/relax/op/tensor/linear_algebra.cc
index 50b53d0c8e..afcc7fefe7 100644
--- a/src/relax/op/tensor/linear_algebra.cc
+++ b/src/relax/op/tensor/linear_algebra.cc
@@ -113,11 +113,17 @@ StructInfo InferStructInfoMatmul(const Call& call, const
BlockBuilder& ctx) {
return TensorStructInfo(ShapeExpr(output_shape), out_dtype);
}
+Call InferMixedPrecisionMatmul(const Call& call, const DataType& out_dtype) {
+ return Downcast<Call>(matmul(call->args[0], call->args[1], out_dtype));
+}
+
TVM_REGISTER_OP("relax.matmul")
.set_num_inputs(2)
.add_argument("x1", "Tensor", "The first input tensor.")
.add_argument("x2", "Tensor", "The second input tensor.")
- .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoMatmul);
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoMatmul)
+ .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kAlways)
+ .set_attr<FInferMixedPrecision>("FInferMixedPrecision",
InferMixedPrecisionMatmul);
} // namespace relax
} // namespace tvm
diff --git a/src/relax/op/tensor/manipulate.cc
b/src/relax/op/tensor/manipulate.cc
index 49f745608f..dbeb6f8d5b 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -107,7 +107,8 @@ TVM_REGISTER_OP("relax.broadcast_to")
.set_num_inputs(2)
.add_argument("x", "Tensor", "The input tensor.")
.add_argument("shape", "Shape", "The target shape.")
- .set_attr<FInferStructInfo>("FInferStructInfo",
InferStructInfoBroadcastTo);
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoBroadcastTo)
+ .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kFollow);
/* relax.concat */
TVM_REGISTER_NODE_TYPE(ConcatAttrs);
@@ -301,7 +302,8 @@ TVM_REGISTER_OP("relax.concat")
.set_num_inputs(1)
.add_argument("tensors", "Tuple of Tensors", "The input list of tensors.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoConcat)
- .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutConcat);
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutConcat)
+ .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kFollow);
/* relax.expand_dims */
TVM_REGISTER_NODE_TYPE(ExpandDimsAttrs);
@@ -397,7 +399,8 @@ TVM_REGISTER_OP("relax.expand_dims")
.set_attrs_type<ExpandDimsAttrs>()
.add_argument("x", "Tensor", "The input tensor.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoExpandDims)
- .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutExpandDims);
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutExpandDims)
+ .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kFollow);
// Helper function for flatten and reshape.
PrimExpr ComputeShapeProduct(const Array<PrimExpr>& shape_values) {
@@ -437,7 +440,8 @@ StructInfo InferStructInfoFlatten(const Call& call, const
BlockBuilder& ctx) {
TVM_REGISTER_OP("relax.flatten")
.set_num_inputs(1)
.add_argument("x", "Tensor", "The input tensor.")
- .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoFlatten);
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoFlatten)
+ .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kFollow);
/* relax.layout_transform */
TVM_REGISTER_NODE_TYPE(LayoutTransformAttrs);
@@ -499,7 +503,8 @@ TVM_REGISTER_OP("relax.layout_transform")
.set_num_inputs(1)
.set_attrs_type<LayoutTransformAttrs>()
.add_argument("x", "Tensor", "The input tensor.")
- .set_attr<FInferStructInfo>("FInferStructInfo",
InferStructInfoLayoutTransform);
+ .set_attr<FInferStructInfo>("FInferStructInfo",
InferStructInfoLayoutTransform)
+ .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kFollow);
/* relax.permute_dims */
TVM_REGISTER_NODE_TYPE(PermuteDimsAttrs);
@@ -610,7 +615,8 @@ TVM_REGISTER_OP("relax.permute_dims")
.set_num_inputs(1)
.add_argument("x", "Tensor", "The input tensor.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoPermuteDims)
- .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutPermuteDims);
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutPermuteDims)
+ .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kFollow);
/* relax.reshape */
Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) {
@@ -752,7 +758,8 @@ TVM_REGISTER_OP("relax.reshape")
.set_num_inputs(2)
.add_argument("x", "Tensor", "The input tensor.")
.add_argument("shape", "Shape", "The input new shape.")
- .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoReshape);
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoReshape)
+ .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kFollow);
/* relax.split */
TVM_REGISTER_NODE_TYPE(SplitAttrs);
@@ -885,7 +892,8 @@ TVM_REGISTER_OP("relax.split")
.set_num_inputs(1)
.add_argument("x", "Tensor", "The input tensor.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoSplit)
- .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutSplit);
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutSplit)
+ .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kFollow);
/* relax.squeeze */
TVM_REGISTER_NODE_TYPE(SqueezeAttrs);
@@ -1040,7 +1048,8 @@ TVM_REGISTER_OP("relax.squeeze")
.set_attrs_type<SqueezeAttrs>()
.add_argument("x", "Tensor", "The input tensor.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoSqueeze)
- .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutSqueeze);
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutSqueeze)
+ .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kFollow);
void CheckCollapseShape(const Call& call, const BlockBuilder& ctx,
const Array<PrimExpr>& data_shape, const
Array<PrimExpr>& target_shape) {
diff --git a/src/relax/op/tensor/ternary.cc b/src/relax/op/tensor/ternary.cc
index 93652f43ef..940192bd8e 100644
--- a/src/relax/op/tensor/ternary.cc
+++ b/src/relax/op/tensor/ternary.cc
@@ -111,7 +111,8 @@ TVM_REGISTER_OP("relax.ewise_fma")
.add_argument("x2", "Tensor", "The right hand operand of the
multiplication")
.add_argument("x3", "Tensor", "The operand of the addition")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoEwiseFMA)
- .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutEwiseFMA);
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutEwiseFMA)
+ .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kFollow);
Expr ewise_fma(Expr x1, Expr x2, Expr x3) {
static const Op& op = Op::Get("relax.ewise_fma");
diff --git a/src/relax/transform/infer_amp_utils.cc
b/src/relax/transform/infer_amp_utils.cc
new file mode 100644
index 0000000000..330fe9a72a
--- /dev/null
+++ b/src/relax/transform/infer_amp_utils.cc
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "infer_amp_utils.h"
+
+namespace tvm {
+namespace relax {
+
+NType NTypeFrom(const StructInfo& sinfo, DataType dtype) {
+ auto fmapleaf = [&](const StructInfo& sinfo) -> NType {
+ const auto* tensor = sinfo.as<TensorStructInfoNode>();
+ ICHECK(tensor) << "Expected TensorStructInfo, but got " << sinfo;
+ if (dtype == DataType::Void())
+ return NType(DLDataType2String(tensor->dtype));
+ else
+ return NType(DLDataType2String(dtype));
+ };
+ return MapToNestedMsg<String>(sinfo, fmapleaf);
+}
+
+NType NTypeFrom(const Expr& expr, DataType dtype) { return
NTypeFrom(GetStructInfo(expr), dtype); }
+
+NType NTypeMerge(const NType& a, const NType& b) {
+ auto fcombine = [&](const String& a_str, const String& b_str) -> String {
+ DataType a = DataType(String2DLDataType(a_str));
+ DataType b = DataType(String2DLDataType(b_str));
+ ICHECK_EQ(a.code(), b.code());
+ ICHECK_EQ(a.lanes(), b.lanes());
+ return a.bits() > b.bits() ? a_str : b_str;
+ };
+ return CombineNestedMsg<String>(a, b, fcombine);
+}
+
+Array<ObjectRef> InferMixedPrecisionFollow(const Call& call, const DataType&
out_dtype) {
+ return {Integer(MixedPrecisionPolicyKind::kFollow), call};
+}
+
+Array<ObjectRef> InferMixedPrecisionNever(const Call& call, const DataType&
out_dtype) {
+ return {Integer(MixedPrecisionPolicyKind::kNever), call};
+}
+
+} // namespace relax
+} // namespace tvm
diff --git a/src/relax/transform/infer_amp_utils.h
b/src/relax/transform/infer_amp_utils.h
new file mode 100644
index 0000000000..3c98af6db9
--- /dev/null
+++ b/src/relax/transform/infer_amp_utils.h
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file infer_amp_utils.h
+ * \brief Utility functions to be used in to_mixed_precision pass.
+ */
+
+#ifndef TVM_RELAX_TRANSFORM_INFER_AMP_UTILS_H_
+#define TVM_RELAX_TRANSFORM_INFER_AMP_UTILS_H_
+
+#include <tvm/relax/attrs/nn.h>
+#include <tvm/relax/expr.h>
+#include <tvm/relax/nested_msg.h>
+#include <tvm/relax/op_attr_types.h>
+#include <tvm/tir/data_layout.h>
+
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+namespace tvm {
+namespace relax {
+
+using runtime::DLDataType2String;
+using runtime::String;
+using runtime::String2DLDataType;
+
+enum MixedPrecisionPolicyKind : int { kAlways = 0, kFollow = 1, kNever = 2 };
+
+/*! \brief the operator pattern */
+using TMixedPrecisionPolicy = int;
+
+// NType is the message we want to track for vars with nested tensorstructinfo
+// which represents the realization decision of the var.
+// The string is the name of the dtype decision.
+using NType = NestedMsg<String>;
+
+struct NTypeEqual {
+ bool operator()(const NType& a, const NType& b) const {
+ auto dtype_equal = [](const String& a, const String& b) { return a == b; };
+ return Equal(a, b, dtype_equal);
+ }
+};
+
+// Construct a NType from an StructInfo
+NType NTypeFrom(const StructInfo& sinfo, DataType dtype = DataType::Void());
+
+// Construct a NType from an Expr
+NType NTypeFrom(const Expr& expr, DataType dtype = DataType::Void());
+
+// Merge two messages, we keep the higher precision type for each leaf tensor
+NType NTypeMerge(const NType& a, const NType& b);
+
+// The map that notes the NType message of each var
+using VarDTypeMap = std::unordered_map<Var, NType, ObjectPtrHash,
ObjectPtrEqual>;
+
+// Call is a call node, out_dtype is the expected output_dtype
+using FInferMixedPrecision =
+ runtime::TypedPackedFunc<Call(const Call& call_node, const DataType&
out_dtype)>;
+
+Array<ObjectRef> InferMixedPrecisionFollow(const Call& call, const DataType&
out_dtype);
+
+Array<ObjectRef> InferMixedPrecisionNever(const Call& call, const DataType&
out_dtype);
+
+} // namespace relax
+} // namespace tvm
+
+#endif // TVM_RELAX_TRANSFORM_INFER_AMP_UTILS_H_
diff --git a/src/relax/transform/to_mixed_precision.cc
b/src/relax/transform/to_mixed_precision.cc
new file mode 100644
index 0000000000..4728f81b63
--- /dev/null
+++ b/src/relax/transform/to_mixed_precision.cc
@@ -0,0 +1,538 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/relax/transform/to_mixed_precision.cc
+ * \brief Automatic mixed precision pass.
+ */
+
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/op_attr_types.h>
+#include <tvm/relax/transform.h>
+
+#include <array>
+
+#include "../op/nn/convolution.h"
+#include "../op/tensor/datatype.h"
+#include "../op/tensor/linear_algebra.h"
+#include "infer_amp_utils.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relax {
+
+using runtime::String;
+
+int GetMixedPrecisionInfo(const CallNode* call_node) {
+ const OpNode* op_node = call_node->op.as<OpNode>();
+ if (op_node == nullptr) {
+ return -1;
+ }
+ Op op = GetRef<Op>(op_node);
+ auto attr_map =
Op::GetAttrMap<TMixedPrecisionPolicy>("TMixedPrecisionPolicy");
+ return attr_map.count(op) ? attr_map[op] : MixedPrecisionPolicyKind::kNever;
+}
+
+/*!
+ * \brief Main logic to automatically cast fp32 input modules to fp16 for
certain ops.
+ *
+ * Structurally speaking, a Relax function is composed of a series of
VarBinding and
+ * MatchCast. And a specific class of VarBindings is the basic unit we want to
rewrite.
+ * Formally, they are of the form:
+ *
+ * var = Call(Op, [args], attrs)
+ *
+ * where Op is a specific op we want to rewrite, and attrs is the attributes
of the op.
+ * var and args are all exprs with type Tensor or Tuple of Tensors. They might
+ * be vars, constants, or Tuple of vars and constants.
+ * Depending on the properties of the op, we may have 3 different ways to
rewrite it:
+ *
+ * 1. kAlways: Always cast the args to fp16
+ * Currently, this is only used for gemm and conv ops (to favor the use of
TensorCore)
+ * We always cast the input args to fp16, and the dtype of the accumulator
is configured
+ * by the global output_dtype parameter (default to fp32). We cast the
output to fp16.
+ *
+ * 2. kFollow: If any of the args is fp32, cast all args to fp32. Otherwise,
use fp16.
+ *
+ * 3. kNever: Never cast the args to fp16. Always cast all args to fp32 (the
original dtype).
+ * Some ops, such as softmax, have numerical issues when using fp16. We
will always use fp32
+ * to ensure the correctness.
+ *
+ * Note that in this case, we will actively cast the arg to fp16 only when
it's used in kAlways.
+ * This is to ensure that we have numerical stability to the best effort.
+ *
+ * DTypeDecisionCollector:
+ * Note that if some tensor is only used in kAlways ops, we can store it in
fp16 without worsening
+ * numerical stability or using more storage. We use a backward propagation
pass to detect such
+ * tensors. We will store the information of each var in the only_fp16_map_.
+ *
+ * We reuse the NType struct to store the information of each var. There
are 3 kinds of info:
+ * - Unknown (Float0): we never encounter a use of this tensor
+ * - Float16: we only encounter uses of this tensor in kAlways ops
+ * - Float32: we encounter some use of this tensor outside of kAlways ops
+ * The info value forms a semi-lattice, where Float0 is the top, Float16 is
the middle, and
+ * Float32 is the bottom. The lower bound of two info values is the one with
more bits.
+ *
+ * ToMixedPrecisionRewriter:
+ * We will then use a forward propagation pass to rewrite the program. Since
we only keep one
+ * specific data type for each var, and we will cast the var to the required
dtype locally when we
+ * encounter its use if needed. Note that we may cast the var to some
certain dtype multiple
+ * times, but we decide not to store and reuse the casted copy due to the
storage concern and to
+ * be more friendly to inlining and operator fusion. We will store the var
to fp16 if it's only
+ * used in kAlways ops, otherwise we will store it as the natural output
dtype of the op.
+ *
+ * The information of each op is registered in the
+ * Op::GetAttr<FInferMixedPrecision>("FInferMixedPrecision"). The registered
function has signature:
+ * FInferMixedPrecision. We will call the registered function with the
original call and the global
+ * output_dtype parameter. The registered function will return the policy of
the op, whether the op
+ * can adjust the dtype of the accumulator, and the new call node with
output_dtype set to the
+ * global output_dtype parameter.
+ *
+ * Key design: wrap_param op
+ * We need to use fp16 parameters (which appear as constants in the
program), but the type
+ * inference will fail if some parameters are fp16 and some are fp32 in the
original module. To
+ * solve this, we introduce a new op wrap_param, which will wrap the
original parameter and cast
+ * it to fp32 var.
+ *
+ * When we encounter the var afterwards, we will directly replace it with
the parameter. This
+ * information is tracked by the const_map_.
+ */
+class DTypeDecisionCollector : public ExprVisitor {
+ public:
+ explicit DTypeDecisionCollector(DataType output_dtype) :
output_dtype_(output_dtype) {}
+
+ static VarDTypeMap Collect(Function func, DataType output_dtype) {
+ DTypeDecisionCollector collector(output_dtype);
+ collector.VisitExpr(func);
+ return std::move(collector.only_fp16_map_);
+ }
+
+ private:
+ NType GetDType(const Var& var) {
+ auto it = only_fp16_map_.find(var);
+ if (it == only_fp16_map_.end()) {
+ // we never encounter this var before
+ NType unknown = NTypeFrom(var, unknown_);
+ only_fp16_map_[var] = unknown;
+ return unknown;
+ }
+ return it->second;
+ }
+
+ // merge the message for a var
+ void UpdateVarDTypeMap(const Var& var, const NType& dtype) {
+ auto it = only_fp16_map_.find(var);
+ if (it == only_fp16_map_.end()) {
+ only_fp16_map_[var] = dtype;
+ } else {
+ only_fp16_map_[var] = NTypeMerge(it->second, dtype);
+ }
+ }
+
+ // merge the message for all vars in the expr list
+ void RequireArgsToType(Array<Expr> args, Array<NType> to) {
+ ICHECK(args.size() == to.size()) << "Invalid target dtypes";
+ for (size_t i = 0; i < args.size(); ++i) {
+ auto fvisitleaf = [&](const Expr& expr, NType to) {
+ if (const auto* var = expr.as<VarNode>()) {
+ UpdateVarDTypeMap(GetRef<Var>(var), to);
+ } else if (expr->IsInstance<ConstantNode>()) {
+ // Constant can be cast anyway, so we don't need to do anything
here
+ return;
+ } else {
+ LOG(FATAL) << "Unsupported argument type: " << expr->GetTypeKey();
+ }
+ };
+ DecomposeNestedMsg(args[i], to[i], fvisitleaf);
+ }
+ }
+
+ // merge the message for all vars in the expr list
+ void RequireArgsToType(Array<Expr> args, DataType to) {
+ std::vector<Expr> arg_arr;
+ std::vector<NType> to_arr;
+ for (const Expr& arg : args) {
+ if (IsNestedTensor(arg)) {
+ // only require the nested tensor args
+ arg_arr.push_back(arg);
+ to_arr.push_back(NTypeFrom(arg, to));
+ }
+ }
+ RequireArgsToType(std::move(arg_arr), std::move(to_arr));
+ }
+
+ void VisitVars_(const VarNode* op) {
+ Var var = GetRef<Var>(op);
+ if (IsNestedTensor(var)) {
+ // require the var to be fp32 (its original dtype)
+ UpdateVarDTypeMap(var, NTypeFrom(var, fp32_));
+ return;
+ }
+ ExprVisitor::VisitExpr_(op);
+ }
+
+ void VisitExpr_(const VarNode* op) final { VisitVars_(op); }
+
+ void VisitExpr_(const DataflowVarNode* op) final { VisitVars_(op); }
+
+ void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node)
final {
+ auto policy = GetMixedPrecisionInfo(call_node);
+ if (policy == -1) {
+ ExprVisitor::VisitBinding_(binding, call_node);
+ return;
+ }
+ if (policy == kAlways) {
+ // require inputs to be fp16
+ RequireArgsToType(call_node->args, fp16_);
+ } else if (policy == kFollow || policy == kNever) {
+ // require inputs to be fp32 (the original dtype)
+ RequireArgsToType(call_node->args, fp32_);
+ } else {
+ LOG(FATAL) << "Unsupported TMixedPrecisionPolicy: " << policy;
+ }
+ }
+
+ void VisitBinding_(const VarBindingNode* binding, const TupleNode*
tuple_node) final {
+ // require input fields to be the type of the lhs field respectively
+ NType lhs_type = GetDType(binding->var);
+ RequireArgsToType(tuple_node->fields, lhs_type.NestedArray());
+ }
+
+ void VisitBinding_(const VarBindingNode* binding,
+ const TupleGetItemNode* tuple_get_item_node) final {
+ // require the i-th field of the rhs tuple to be the type of the lhs
+ NType lhs_type = GetDType(binding->var);
+ std::vector<NType> require_rhs;
+ const TupleStructInfoNode* sinfo =
+ tuple_get_item_node->tuple->struct_info_.as<TupleStructInfoNode>();
+ ICHECK(sinfo != nullptr) << "TupleGetItemNode must have TupleStructInfo";
+ for (size_t i = 0; i < sinfo->fields.size(); ++i) {
+ if (i == (size_t)tuple_get_item_node->index) {
+ require_rhs.push_back(lhs_type);
+ } else {
+ require_rhs.push_back(NTypeFrom(sinfo->fields[i], unknown_));
+ }
+ }
+ RequireArgsToType({tuple_get_item_node->tuple}, {NType(require_rhs)});
+ }
+
+ // override the following methods to visit in backward order
+ void VisitExpr_(const SeqExprNode* op) final {
+ this->VisitSpan(op->span);
+ this->VisitExpr(op->body);
+ for (auto it = op->blocks.rbegin(); it != op->blocks.rend(); it++) {
+ this->VisitBindingBlock(*it);
+ }
+
+ if (auto* sinfo = op->struct_info_.as<StructInfoNode>()) {
+ this->VisitExprDepStructInfoField(GetRef<StructInfo>(sinfo));
+ }
+ }
+
+ void VisitBindingBlock_(const BindingBlockNode* block) { return; }
+
+ void VisitBindingBlock_(const DataflowBlockNode* block) {
+ for (auto it = block->bindings.rbegin(); it != block->bindings.rend();
it++) {
+ this->VisitBinding(*it);
+ }
+ }
+
+ void VisitExpr_(const IfNode* op) final {
+ this->VisitSpan(op->span);
+ this->VisitExpr(op->true_branch);
+ this->VisitExpr(op->false_branch);
+ this->VisitExpr(op->cond);
+
+ if (auto* sinfo = op->struct_info_.as<StructInfoNode>()) {
+ this->VisitExprDepStructInfoField(GetRef<StructInfo>(sinfo));
+ }
+ }
+
+ DataType unknown_ = DataType(DataType::TypeCode::kFloat, 0, 1);
+ DataType fp16_ = DataType(DataType::TypeCode::kFloat, 16, 1);
+ DataType fp32_ = DataType(DataType::TypeCode::kFloat, 32, 1);
+ DataType output_dtype_;
+ VarDTypeMap only_fp16_map_;
+};
+
+class ToMixedPrecisionRewriter : public ExprMutator {
+ public:
+ explicit ToMixedPrecisionRewriter(const VarDTypeMap* only_fp16_map, DataType
output_dtype)
+ : only_fp16_map_(only_fp16_map), output_dtype_(output_dtype) {}
+
+ private:
+ Var GetRemapped(const Var& var) {
+ auto it = var_remap_.find(var->vid);
+ return it == var_remap_.end() ? var : it->second;
+ }
+
+ Array<Expr> RemapArgs(const Array<Expr>& args) {
+ Array<Expr> new_args;
+ for (const auto& arg : args) {
+ new_args.push_back(VarReplacer::Replace(arg, var_remap_));
+ }
+ return new_args;
+ }
+
+ // Util function to rewrite the expr to the given dtype
+ // rewrite each leaf tensor to the given dtype if necessary
+ // Note that this function only accepts expr with nested tensor type
+ Expr RewriteExpr(const Expr& expr, const NType& to) {
+ auto fvisitleaf = [&](const Expr& expr, std::array<NType, 1> to) -> Expr {
+ const auto* tensor = GetStructInfoAs<TensorStructInfoNode>(expr);
+ ICHECK(tensor != nullptr) << "Only support rewriting tensor expr";
+ // We only rewrite the expr if the dtype is not the same as the given
dtype
+ if (NTypeEqual()(to[0], NTypeFrom(expr))) return expr;
+ // We only rewrite the expr if the dtype is fp16 or fp32, dtypes such as
int32, float64 is not
+ // supported to be rewritten
+ if (tensor->dtype != fp16_ && tensor->dtype != fp32_) return expr;
+ return astype(expr, DataType(String2DLDataType(to[0].LeafValue())));
+ };
+ return TransformTupleLeaf<String>(expr, std::array<NType, 1>({to}),
fvisitleaf);
+ }
+
+ Array<Expr> RewriteArgs(const Array<Expr>& args, DataType to) {
+ Array<Expr> new_args;
+ for (const Expr& arg : args) {
+ if (IsNestedTensor(arg)) {
+ new_args.push_back(RewriteExpr(arg, NTypeFrom(arg, to)));
+ } else {
+ new_args.push_back(arg);
+ }
+ }
+ return new_args;
+ }
+
+ // Util function to check if any of the tensors in the args is fp32
+ bool AnyArgIsFP32(const NType& cur_type) {
+ bool result = false;
+ auto fvisitleaf = [&, this](const String& dtype) {
+ if (dtype == "float32") {
+ result = true;
+ }
+ };
+ ForEachLeaf<String>(cur_type, fvisitleaf);
+ return result;
+ }
+
+ bool AnyArgIsFP32(const Array<Expr>& args) {
+ for (const Expr& arg : args) {
+ if (IsNestedTensor(arg)) {
+ if (AnyArgIsFP32(NTypeFrom(arg))) return true;
+ }
+ }
+ return false;
+ }
+
+ void CastIfFp16Only(const Var& var) {
+ ICHECK(builder_->CurrentBlockIsDataFlow());
+ // Get the current remapped var
+ Var cur_var = GetRemapped(var);
+ // Store the tensors that are fp16 only to fp16
+ auto it = only_fp16_map_->find(var);
+ if (it == only_fp16_map_->end()) return;
+ // Get the to dtype, cast to fp16 if the var is fp16 only, otherwise do
nothing
+ auto fcombine = [](const String& from, const String& required) -> String {
+ return required == "float16" ? required : from;
+ };
+ NType from = NTypeFrom(cur_var);
+ NType to = CombineNestedMsg<String>(from, it->second, fcombine);
+ Expr rewrite = RewriteExpr(cur_var, to);
+ // If cur_var is not rewritten, we don't need to emit a new var
+ if (!rewrite.same_as(cur_var)) {
+ // Emit a new var, and update the var remap
+ var_remap_[var->vid] = builder_->Emit(rewrite);
+ }
+ }
+
+ Expr VisitVar_(const Var& var) {
+ // We rewrite the remapped var to the original dtype
+ auto it = var_remap_.find(var->vid);
+ if (it != var_remap_.end()) {
+ return RewriteExpr(it->second, NTypeFrom(var));
+ }
+ return var;
+ }
+
+ Expr VisitExpr_(const VarNode* op) final {
+ if (!builder_->CurrentBlockIsDataFlow()) {
+ return ExprMutator::VisitExpr_(op);
+ }
+ return VisitVar_(GetRef<Var>(op));
+ }
+
+ Expr VisitExpr_(const DataflowVarNode* op) final {
+ if (!builder_->CurrentBlockIsDataFlow()) {
+ return ExprMutator::VisitExpr_(op);
+ }
+ return VisitVar_(GetRef<Var>(op));
+ }
+
+ void VisitBinding(const Binding& binding) {
+ ExprMutator::VisitBinding(binding);
+ if (!builder_->CurrentBlockIsDataFlow()) return;
+ CastIfFp16Only(binding->var);
+ }
+
+ void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node)
final {
+ if (!builder_->CurrentBlockIsDataFlow()) {
+ ExprMutator::VisitBinding_(binding, call_node);
+ return;
+ }
+ auto policy = GetMixedPrecisionInfo(call_node);
+ if (policy == -1) {
+ // not an op call
+ ExprMutator::VisitBinding_(binding, call_node);
+ return;
+ }
+ // var = Call(op)
+ const auto* op_node = call_node->op.as<OpNode>();
+ ICHECK(op_node != nullptr);
+ Op op = GetRef<Op>(op_node);
+ if (wrap_param_op.same_as(op)) {
+ // wrap_param
+ ReEmitBinding(binding, call_node->args[0]);
+ return;
+ }
+ DataType to;
+ ObjectPtr<CallNode> new_call = make_object<CallNode>(*call_node);
+ // We first remap the args to the current vars according to the
var_remap_
+ new_call->args = std::move(RemapArgs(call_node->args));
+ // Then we rewrite the args according to the policy
+ if (policy == kAlways) {
+ to = fp16_;
+ auto attr_map =
Op::GetAttrMap<FInferMixedPrecision>("FInferMixedPrecision");
+ ICHECK(attr_map.count(op));
+ auto f = attr_map[op];
+ new_call = make_object<CallNode>(*(f(Call(new_call),
output_dtype_).get()));
+ } else if (policy == kFollow) {
+ to = AnyArgIsFP32(new_call->args) ? fp32_ : fp16_;
+ } else if (policy == kNever) {
+ to = fp32_;
+ } else {
+ LOG(FATAL) << "Unsupported TMixedPrecisionPolicy: " << policy;
+ }
+ new_call->args = std::move(RewriteArgs(new_call->args, to));
+ new_call->struct_info_ = NullOpt;
+ Expr new_value = builder_->Normalize(Call(new_call));
+ if (policy == kAlways && binding->var->IsInstance<DataflowVarNode>()) {
+ // kAlways: store the tensors to fp16
+ // But global vars will be stored to the original dtype anyway (see
below)
+ new_value = RewriteExpr(new_value, NTypeFrom(new_value, fp16_));
+ }
+ if (!binding->var->IsInstance<DataflowVarNode>()) {
+ // Global var: store the tensors to the original dtype
+ NType to = NTypeFrom(binding->var);
+ new_value = RewriteExpr(new_value, to);
+ }
+ ReEmitBinding(binding, builder_->Normalize(new_value));
+ }
+
+ void VisitBinding_(const VarBindingNode* binding, const TupleNode*
tuple_node) final {
+ if (!builder_->CurrentBlockIsDataFlow()) {
+ ExprMutator::VisitBinding_(binding, tuple_node);
+ return;
+ }
+ ObjectPtr<TupleNode> new_tuple = make_object<TupleNode>(*tuple_node);
+ new_tuple->fields = std::move(RemapArgs(tuple_node->fields));
+ new_tuple->struct_info_ = NullOpt;
+ Expr new_value = builder_->Normalize(Tuple(new_tuple));
+ if (!binding->var->IsInstance<DataflowVarNode>()) {
+ // Global var: store the tensors to the original dtype
+ NType to = NTypeFrom(binding->var);
+ new_value = RewriteExpr(new_value, to);
+ }
+ ReEmitBinding(binding, builder_->Normalize(new_value));
+ }
+
+ void VisitBinding_(const VarBindingNode* binding,
+ const TupleGetItemNode* tuple_get_item_node) final {
+ if (!builder_->CurrentBlockIsDataFlow()) {
+ // We don't need to rewrite the tuple_get_item outside a dataflow block
+ ExprMutator::VisitBinding_(binding, tuple_get_item_node);
+ return;
+ }
+ ObjectPtr<TupleGetItemNode> new_tuple_get_item =
+ make_object<TupleGetItemNode>(*tuple_get_item_node);
+ new_tuple_get_item->tuple = RemapArgs({tuple_get_item_node->tuple})[0];
+ new_tuple_get_item->struct_info_ = NullOpt;
+ Expr new_value = TupleGetItem(new_tuple_get_item);
+ if (!binding->var->IsInstance<DataflowVarNode>()) {
+ // Global var: store the tensors to the original dtype
+ NType to = NTypeFrom(binding->var);
+ new_value = RewriteExpr(new_value, to);
+ }
+ ReEmitBinding(binding, builder_->Normalize(new_value));
+ }
+
+ BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) {
+ builder_->BeginDataflowBlock();
+ // prepare local versions of params here, if they are fp16 expected only
+ for (auto param : params_) {
+ CastIfFp16Only(param);
+ }
+ for (auto binding : block->bindings) {
+ this->VisitBinding(binding);
+ }
+ for (auto param : params_) {
+ // remove the local version of params
+ auto it = var_remap_.find(param->vid);
+ if (it != var_remap_.end()) {
+ var_remap_.erase(it);
+ }
+ }
+ return builder_->EndBlock();
+ }
+
+ Expr VisitExpr_(const FunctionNode* op) final {
+ params_ = op->params;
+ return ExprMutator::VisitExpr_(op);
+ }
+
+ const VarDTypeMap* only_fp16_map_;
+
+ DataType fp16_ = DataType(DataType::TypeCode::kFloat, 16, 1);
+ DataType fp32_ = DataType(DataType::TypeCode::kFloat, 32, 1);
+ DataType output_dtype_;
+ Array<Var> params_;
+
+ const Op& wrap_param_op = Op::Get("relax.wrap_param");
+};
+
+Expr ToMixedPrecision(const Function& f, const DataType& out_dtype) {
+ VarDTypeMap only_fp16_map = std::move(DTypeDecisionCollector::Collect(f,
out_dtype));
+ ToMixedPrecisionRewriter mutator(&only_fp16_map, out_dtype);
+ return mutator(f);
+}
+
+namespace transform {
+
+Pass ToMixedPrecision(const DataType& out_dtype) {
+ runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>
pass_func =
+ [=](Function f, IRModule m, PassContext pc) {
+ return Downcast<Function>(ToMixedPrecision(f, out_dtype));
+ };
+ return CreateFunctionPass(pass_func, 0, "ToMixedPrecision", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.ToMixedPrecision").set_body_typed(ToMixedPrecision);
+
+} // namespace transform
+
+} // namespace relax
+} // namespace tvm
diff --git a/tests/python/relax/test_op_datatype.py
b/tests/python/relax/test_op_datatype.py
index 56bbe464cf..48820b9e2e 100644
--- a/tests/python/relax/test_op_datatype.py
+++ b/tests/python/relax/test_op_datatype.py
@@ -14,6 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import numpy as np # type: ignore
+
+
import pytest
import tvm
import tvm.testing
@@ -25,7 +28,9 @@ from tvm.script import relax as R
def test_op_correctness():
x = relax.Var("x", R.Tensor((2, 3), "float32"))
+ c = relax.Constant(tvm.nd.array(np.array([1, 2, 3], dtype="float16")))
assert relax.op.astype(x, "float16").op == Op.get("relax.astype")
+ assert relax.op.wrap_param(c, "float32").op == Op.get("relax.wrap_param")
def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo:
relax.StructInfo):
@@ -101,5 +106,17 @@ def test_astype_infer_struct_info_wrong_input_type():
bb.normalize(relax.op.astype(x1, "float16"))
+def test_wrap_param_infer_struct_info():
+ bb = relax.BlockBuilder()
+ x0 = relax.Constant(tvm.nd.array(np.zeros([1, 2, 3], dtype="float16")))
+ x1 = relax.Constant(tvm.nd.array(np.zeros([1, 2, 3], dtype="int8")))
+ _check_inference(
+ bb, relax.op.wrap_param(x0, "float32"), relax.TensorStructInfo((1, 2,
3), "float32")
+ )
+ _check_inference(
+ bb, relax.op.wrap_param(x1, "int32"), relax.TensorStructInfo((1, 2,
3), "int32")
+ )
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_transform_to_mixed_precision.py
b/tests/python/relax/test_transform_to_mixed_precision.py
new file mode 100644
index 0000000000..b9409bff52
--- /dev/null
+++ b/tests/python/relax/test_transform_to_mixed_precision.py
@@ -0,0 +1,540 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+import tvm
+from tvm import relax
+import tvm.testing
+from tvm.relax.transform import ToMixedPrecision
+from tvm.script.parser import ir as I, relax as R
+
+
+def _assert_test(input, expected):
+ mod = ToMixedPrecision()(input)
+ tvm.ir.assert_structural_equal(mod, expected)
+
+
+def test_conv2d():
+ @I.ir_module
+ class Input:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3),
"float32")
+ ) -> R.Tensor(None, "float32", ndim=4):
+ with R.dataflow():
+ gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w,
out_dtype="float32")
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3,
3, 3), dtype="float32")
+ ) -> R.Tensor((2, 4, 26, 26), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((2, 3, 28, 28), dtype="float16") = R.astype(x,
dtype="float16")
+ lv1: R.Tensor((4, 3, 3, 3), dtype="float16") = R.astype(w,
dtype="float16")
+ gv: R.Tensor((2, 4, 26, 26), dtype="float32") = R.nn.conv2d(
+ lv,
+ lv1,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NCHW",
+ kernel_layout="OIHW",
+ out_layout="NCHW",
+ out_dtype="float32",
+ )
+ R.output(gv)
+ return gv
+
+ _assert_test(Input, Expected)
+
+
+def test_conv2d_relu():
+ @I.ir_module
+ class Input:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3),
"float32")
+ ) -> R.Tensor(None, "float32", ndim=4):
+ with R.dataflow():
+ lv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w,
out_dtype="float32")
+ gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(lv)
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3,
3, 3), dtype="float32")
+ ) -> R.Tensor((2, 4, 26, 26), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((2, 3, 28, 28), dtype="float16") = R.astype(x,
dtype="float16")
+ lv1: R.Tensor((4, 3, 3, 3), dtype="float16") = R.astype(w,
dtype="float16")
+ lv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.nn.conv2d(
+ lv,
+ lv1,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NCHW",
+ kernel_layout="OIHW",
+ out_layout="NCHW",
+ out_dtype="float32",
+ )
+ lv_1: R.Tensor((2, 4, 26, 26), dtype="float16") =
R.astype(lv2, dtype="float16")
+ lv3: R.Tensor((2, 4, 26, 26), dtype="float16") =
R.nn.relu(lv_1)
+ gv: R.Tensor((2, 4, 26, 26), dtype="float32") = R.astype(lv3,
dtype="float32")
+ R.output(gv)
+ return gv
+
+ _assert_test(Input, Expected)
+
+
+def test_relu_conv2d_relu():
+ @I.ir_module
+ class Input:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3),
"float32")
+ ) -> R.Tensor(None, "float32", ndim=4):
+ with R.dataflow():
+ x0: R.Tensor((2, 3, 28, 28), "float32") = R.nn.relu(x)
+ gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x0, w,
out_dtype="float32")
+ gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv)
+ R.output(gv2)
+ return gv2
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3,
3, 3), dtype="float32")
+ ) -> R.Tensor((2, 4, 26, 26), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((4, 3, 3, 3), dtype="float16") = R.astype(w,
dtype="float16")
+ x0: R.Tensor((2, 3, 28, 28), dtype="float32") = R.nn.relu(x)
+ lv1: R.Tensor((2, 3, 28, 28), dtype="float16") = R.astype(x0,
dtype="float16")
+ lv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.nn.conv2d(
+ lv1,
+ lv,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NCHW",
+ kernel_layout="OIHW",
+ out_layout="NCHW",
+ out_dtype="float32",
+ )
+ gv: R.Tensor((2, 4, 26, 26), dtype="float16") = R.astype(lv2,
dtype="float16")
+ lv3: R.Tensor((2, 4, 26, 26), dtype="float16") = R.nn.relu(gv)
+ gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.astype(lv3,
dtype="float32")
+ R.output(gv2)
+ return gv2
+
+ _assert_test(Input, Expected)
+
+
+def test_conv2d_relu_conv2d():
+ @I.ir_module
+ class Input:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 28, 28), "float32"),
+ w: R.Tensor((4, 3, 3, 3), "float32"),
+ w2: R.Tensor((4, 4, 3, 3), "float32"),
+ ) -> R.Tensor(None, "float32", ndim=4):
+ with R.dataflow():
+ gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w,
out_dtype="float32")
+ gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv)
+ gv3: R.Tensor((2, 4, 24, 24), "float32") = R.nn.conv2d(gv2,
w2, out_dtype="float32")
+ R.output(gv3)
+ return gv3
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 28, 28), dtype="float32"),
+ w: R.Tensor((4, 3, 3, 3), dtype="float32"),
+ w2: R.Tensor((4, 4, 3, 3), dtype="float32"),
+ ) -> R.Tensor((2, 4, 24, 24), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((2, 3, 28, 28), dtype="float16") = R.astype(x,
dtype="float16")
+ lv1: R.Tensor((4, 3, 3, 3), dtype="float16") = R.astype(w,
dtype="float16")
+ lv2: R.Tensor((4, 4, 3, 3), dtype="float16") = R.astype(w2,
dtype="float16")
+ lv3: R.Tensor((2, 4, 26, 26), dtype="float32") = R.nn.conv2d(
+ lv,
+ lv1,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NCHW",
+ kernel_layout="OIHW",
+ out_layout="NCHW",
+ out_dtype="float32",
+ )
+ gv: R.Tensor((2, 4, 26, 26), dtype="float16") = R.astype(lv3,
dtype="float16")
+ gv2: R.Tensor((2, 4, 26, 26), dtype="float16") = R.nn.relu(gv)
+ gv3: R.Tensor((2, 4, 24, 24), dtype="float32") = R.nn.conv2d(
+ gv2,
+ lv2,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NCHW",
+ kernel_layout="OIHW",
+ out_layout="NCHW",
+ out_dtype="float32",
+ )
+ R.output(gv3)
+ return gv3
+
+ _assert_test(Input, Expected)
+
+
+def test_gemm_add_silu():
+ @I.ir_module
+ class Input:
+ @R.function
+ def main(
+ x: R.Tensor((2, 320), "float32"),
+ w1: R.Tensor((320, 1280), "float32"),
+ w2: R.Tensor((2, 1280), "float32"),
+ ) -> R.Tensor(None, "float32", ndim=2):
+ with R.dataflow():
+ gv0: R.Tensor((2, 1280), "float32") = R.matmul(x, w1,
out_dtype="float32")
+ gv1: R.Tensor((2, 1280), "float32") = R.add(gv0, w2)
+ gv2: R.Tensor((2, 1280), "float32") = R.nn.silu(gv1)
+ R.output(gv2)
+ return gv2
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((2, 320), dtype="float32"),
+ w1: R.Tensor((320, 1280), dtype="float32"),
+ w2: R.Tensor((2, 1280), dtype="float32"),
+ ) -> R.Tensor((2, 1280), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((2, 320), dtype="float16") = R.astype(x,
dtype="float16")
+ lv1: R.Tensor((320, 1280), dtype="float16") = R.astype(w1,
dtype="float16")
+ lv2: R.Tensor((2, 1280), dtype="float32") = R.matmul(lv, lv1,
out_dtype="float32")
+ gv0: R.Tensor((2, 1280), dtype="float16") = R.astype(lv2,
dtype="float16")
+ lv3: R.Tensor((2, 1280), dtype="float32") = R.astype(gv0,
dtype="float32")
+ gv1: R.Tensor((2, 1280), dtype="float32") = R.add(lv3, w2)
+ gv2: R.Tensor((2, 1280), dtype="float32") = R.nn.silu(gv1)
+ R.output(gv2)
+ return gv2
+
+ _assert_test(Input, Expected)
+
+
+def test_tuple():
+ @I.ir_module
+ class Input:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 28, 28), "float32"),
+ w: R.Tensor((4, 3, 3, 3), "float32"),
+ w_2: R.Tensor((4, 4, 3, 3), "float32"),
+ ) -> R.Tensor(None, "float32", ndim=4):
+ with R.dataflow():
+ gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w,
out_dtype="float32")
+ gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w,
out_dtype="float32")
+ gv3 = (gv, gv2)
+ gv4 = (gv3, gv2)
+ gv5 = gv4[0]
+ gv6 = gv5[0]
+ gv7 = R.nn.conv2d(gv6, w_2, out_dtype="float32")
+ R.output(gv7)
+ return gv7
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 28, 28), dtype="float32"),
+ w: R.Tensor((4, 3, 3, 3), dtype="float32"),
+ w_2: R.Tensor((4, 4, 3, 3), dtype="float32"),
+ ) -> R.Tensor((2, 4, 24, 24), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((2, 3, 28, 28), dtype="float16") = R.astype(x,
dtype="float16")
+ lv1: R.Tensor((4, 3, 3, 3), dtype="float16") = R.astype(w,
dtype="float16")
+ lv2: R.Tensor((4, 4, 3, 3), dtype="float16") = R.astype(w_2,
dtype="float16")
+ lv3: R.Tensor((2, 4, 26, 26), dtype="float32") = R.nn.conv2d(
+ lv,
+ lv1,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NCHW",
+ kernel_layout="OIHW",
+ out_layout="NCHW",
+ out_dtype="float32",
+ )
+ gv: R.Tensor((2, 4, 26, 26), dtype="float16") = R.astype(lv3,
dtype="float16")
+ lv4: R.Tensor((2, 4, 26, 26), dtype="float32") = R.nn.conv2d(
+ lv,
+ lv1,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NCHW",
+ kernel_layout="OIHW",
+ out_layout="NCHW",
+ out_dtype="float32",
+ )
+ gv2: R.Tensor((2, 4, 26, 26), dtype="float16") = R.astype(lv4,
dtype="float16")
+ gv3: R.Tuple(
+ R.Tensor((2, 4, 26, 26), dtype="float16"),
+ R.Tensor((2, 4, 26, 26), dtype="float16"),
+ ) = (gv, gv2)
+ gv4: R.Tuple(
+ R.Tuple(
+ R.Tensor((2, 4, 26, 26), dtype="float16"),
+ R.Tensor((2, 4, 26, 26), dtype="float16"),
+ ),
+ R.Tensor((2, 4, 26, 26), dtype="float16"),
+ ) = (gv3, gv2)
+ gv5: R.Tuple(
+ R.Tensor((2, 4, 26, 26), dtype="float16"),
+ R.Tensor((2, 4, 26, 26), dtype="float16"),
+ ) = gv4[0]
+ gv6: R.Tensor((2, 4, 26, 26), dtype="float16") = gv5[0]
+ gv7: R.Tensor((2, 4, 24, 24), dtype="float32") = R.nn.conv2d(
+ gv6,
+ lv2,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NCHW",
+ kernel_layout="OIHW",
+ out_layout="NCHW",
+ out_dtype="float32",
+ )
+ R.output(gv7)
+ return gv7
+
+ _assert_test(Input, Expected)
+
+
def test_concat_matmul():
    """ToMixedPrecision on a concat feeding a matmul.

    The input module computes ``matmul(concat(lv10, lv12), w)`` entirely in
    fp32.  The expected module shows the pass:

    * casting the weight ``w`` to fp16 once up front (``lv``),
    * leaving ``R.concat`` in fp32 and casting its result to fp16 (``lv1``)
      only at the matmul boundary,
    * running ``R.matmul`` on fp16 operands while accumulating in fp32
      (``out_dtype="float32"``).

    NOTE(review): why concat itself is not computed in fp16 is not visible
    here — presumably it is registered as a "follow"-precision op in
    to_mixed_precision.cc; confirm against the pass source.
    """

    @I.ir_module
    class Input:
        @R.function
        def main(
            lv10: R.Tensor((2, 160), "float32"),
            lv12: R.Tensor((2, 160), "float32"),
            w: R.Tensor((320, 1280), "float32"),
        ) -> R.Tensor(None, "float32", ndim=2):
            with R.dataflow():
                lv13: R.Tensor((2, 320), "float32") = R.concat((lv10, lv12), axis=-1)
                lv14: R.Tensor((2, 1280), "float32") = R.matmul(lv13, w, out_dtype="float32")
                R.output(lv14)
            return lv14

    @I.ir_module
    class Expected:
        @R.function
        def main(
            lv10: R.Tensor((2, 160), dtype="float32"),
            lv12: R.Tensor((2, 160), dtype="float32"),
            w: R.Tensor((320, 1280), dtype="float32"),
        ) -> R.Tensor((2, 1280), dtype="float32"):
            with R.dataflow():
                # Weight is cast to fp16 once; the activation is cast only
                # after the fp32 concat.
                lv: R.Tensor((320, 1280), dtype="float16") = R.astype(w, dtype="float16")
                lv13: R.Tensor((2, 320), dtype="float32") = R.concat((lv10, lv12), axis=-1)
                lv1: R.Tensor((2, 320), dtype="float16") = R.astype(lv13, dtype="float16")
                # fp16 inputs, fp32 accumulation.
                lv14: R.Tensor((2, 1280), dtype="float32") = R.matmul(lv1, lv, out_dtype="float32")
                R.output(lv14)
            return lv14

    _assert_test(Input, Expected)
+
+
def test_conv2d_softmax():
    """ToMixedPrecision on a conv2d whose result is added to a softmax.

    The expected module shows:

    * ``R.nn.conv2d`` run on fp16-cast inputs with fp32 accumulation, its
      result then cast down to fp16 (``gv``),
    * ``R.nn.softmax`` left entirely in fp32 on the original ``x``
      (presumably softmax is registered as an fp32-only op for numerical
      stability — confirm against to_mixed_precision.cc),
    * the conv output cast back up to fp32 (``lv3``) so the final ``R.add``
      with the softmax result happens in fp32.

    NOTE(review): the Input annotations claim (2, 3, 26, 26) for both the
    conv2d (padding=(1, 1), 3x3 kernel on 28x28 -> 28x28) and the softmax of
    ``x`` (shape-preserving); Expected carries (2, 3, 28, 28).  The Input
    annotations look stale — confirm the parser/normalizer tolerates them.
    """

    @I.ir_module
    class Input:
        @R.function
        def main(
            x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((3, 3, 3, 3), "float32")
        ) -> R.Tensor(None, "float32", ndim=4):
            with R.dataflow():
                gv: R.Tensor((2, 3, 26, 26), "float32") = R.nn.conv2d(x, w, padding=(1, 1))
                gv1: R.Tensor((2, 3, 26, 26), "float32") = R.nn.softmax(x, axis=1)
                gv2 = R.add(gv, gv1)
                R.output(gv2)
            return gv2

    @I.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((3, 3, 3, 3), dtype="float32")
        ) -> R.Tensor((2, 3, 26, 26), dtype="float32"):
            with R.dataflow():
                # Both conv operands are cast to fp16.
                lv: R.Tensor((3, 3, 3, 3), dtype="float16") = R.astype(w, dtype="float16")
                lv1: R.Tensor((2, 3, 28, 28), dtype="float16") = R.astype(x, dtype="float16")
                lv2: R.Tensor((2, 3, 28, 28), dtype="float32") = R.nn.conv2d(
                    lv1,
                    lv,
                    strides=[1, 1],
                    padding=[1, 1, 1, 1],
                    dilation=[1, 1],
                    groups=1,
                    data_layout="NCHW",
                    kernel_layout="OIHW",
                    out_layout="NCHW",
                    out_dtype="float32",
                )
                gv: R.Tensor((2, 3, 28, 28), dtype="float16") = R.astype(lv2, dtype="float16")
                # softmax stays fp32 and consumes the ORIGINAL x, not a cast.
                gv1: R.Tensor((2, 3, 28, 28), dtype="float32") = R.nn.softmax(x, axis=1)
                # Cast the conv result back up so the add runs in fp32.
                lv3: R.Tensor((2, 3, 28, 28), dtype="float32") = R.astype(gv, dtype="float32")
                gv2: R.Tensor((2, 3, 28, 28), dtype="float32") = R.add(lv3, gv1)
                R.output(gv2)
            return gv2

    _assert_test(Input, Expected)
+
+
def test_conv2d_bias_conv2d():
    """ToMixedPrecision on a conv-bias-conv chain with fp16 weight params.

    The input module takes weights that are ALREADY fp16 and lifts them to
    fp32 via ``R.wrap_param`` before computing in fp32.  The expected module
    shows the pass eliminating every ``wrap_param`` (the fp16 parameters are
    bound directly, e.g. ``lv_1 = w0``), running the bias reshapes/adds in
    fp16, keeping fp32 accumulation inside each conv2d, and casting the final
    result back to fp32 at the graph output.

    Both modules then have the same random fp16 weights bound via
    ``BindParams`` before the structural comparison in ``_assert_test``.

    NOTE(review): this test decorates Input with ``@tvm.script.ir_module``
    while the sibling tests use ``@I.ir_module`` — same decorator, but worth
    unifying for consistency.
    """

    @tvm.script.ir_module
    class Input:
        @R.function
        def main(
            z: R.Tensor((1, 4, 64, 64), dtype="float32"),
            w0: R.Tensor((512, 4, 3, 3), dtype="float16"),
            w1: R.Tensor((512,), dtype="float16"),
            w2: R.Tensor((4, 4, 1, 1), dtype="float16"),
            w3: R.Tensor((4,), dtype="float16"),
        ) -> R.Tensor((1, 512, 64, 64), dtype="float32"):
            # block 0
            with R.dataflow():
                # fp16 parameters are wrapped up to fp32 for the fp32 graph.
                lv: R.Tensor((512, 4, 3, 3), dtype="float32") = R.wrap_param(w0, dtype="float32")
                lv1: R.Tensor((512,), dtype="float32") = R.wrap_param(w1, dtype="float32")
                lv140: R.Tensor((4, 4, 1, 1), dtype="float32") = R.wrap_param(w2, dtype="float32")
                lv141: R.Tensor((4,), dtype="float32") = R.wrap_param(w3, dtype="float32")
                lv142: R.Tensor((1, 4, 64, 64), dtype="float32") = R.nn.conv2d(
                    z,
                    lv140,
                    strides=[1, 1],
                    padding=[0, 0, 0, 0],
                    dilation=[1, 1],
                    groups=1,
                    data_layout="NCHW",
                    kernel_layout="OIHW",
                    out_layout="NCHW",
                    out_dtype="float32",
                )
                lv143: R.Tensor((1, 4, 1, 1), dtype="float32") = R.reshape(lv141, (1, 4, 1, 1))
                lv144: R.Tensor((1, 4, 64, 64), dtype="float32") = R.add(lv142, lv143)
                lv145: R.Tensor((1, 512, 64, 64), dtype="float32") = R.nn.conv2d(
                    lv144,
                    lv,
                    strides=[1, 1],
                    padding=[1, 1, 1, 1],
                    dilation=[1, 1],
                    groups=1,
                    data_layout="NCHW",
                    kernel_layout="OIHW",
                    out_layout="NCHW",
                    out_dtype="float32",
                )
                lv146: R.Tensor((1, 512, 1, 1), dtype="float32") = R.reshape(lv1, (1, 512, 1, 1))
                lv147: R.Tensor((1, 512, 64, 64), dtype="float32") = R.add(lv145, lv146)
                gv: R.Tensor((1, 512, 64, 64), dtype="float32") = lv147
                R.output(gv)
            return gv

    @I.ir_module
    class Expected:
        @R.function
        def main(
            z: R.Tensor((1, 4, 64, 64), dtype="float32"),
            w0: R.Tensor((512, 4, 3, 3), dtype="float16"),
            w1: R.Tensor((512,), dtype="float16"),
            w2: R.Tensor((4, 4, 1, 1), dtype="float16"),
            w3: R.Tensor((4,), dtype="float16"),
        ) -> R.Tensor((1, 512, 64, 64), dtype="float32"):
            # block 0
            with R.dataflow():
                # Only the fp32 activation needs a cast; the fp16 parameters
                # are consumed directly (wrap_param is gone).
                lv: R.Tensor((1, 4, 64, 64), dtype="float16") = R.astype(z, dtype="float16")
                lv_1: R.Tensor((512, 4, 3, 3), dtype="float16") = w0
                lv1: R.Tensor((512,), dtype="float16") = w1
                lv140: R.Tensor((4, 4, 1, 1), dtype="float16") = w2
                lv141: R.Tensor((4,), dtype="float16") = w3
                lv1_1: R.Tensor((1, 4, 64, 64), dtype="float32") = R.nn.conv2d(
                    lv,
                    lv140,
                    strides=[1, 1],
                    padding=[0, 0, 0, 0],
                    dilation=[1, 1],
                    groups=1,
                    data_layout="NCHW",
                    kernel_layout="OIHW",
                    out_layout="NCHW",
                    out_dtype="float32",
                )
                # conv accumulates fp32, then drops to fp16 for the bias add.
                lv142: R.Tensor((1, 4, 64, 64), dtype="float16") = R.astype(lv1_1, dtype="float16")
                lv143: R.Tensor((1, 4, 1, 1), dtype="float16") = R.reshape(lv141, (1, 4, 1, 1))
                lv144: R.Tensor((1, 4, 64, 64), dtype="float16") = R.add(lv142, lv143)
                lv2: R.Tensor((1, 512, 64, 64), dtype="float32") = R.nn.conv2d(
                    lv144,
                    lv_1,
                    strides=[1, 1],
                    padding=[1, 1, 1, 1],
                    dilation=[1, 1],
                    groups=1,
                    data_layout="NCHW",
                    kernel_layout="OIHW",
                    out_layout="NCHW",
                    out_dtype="float32",
                )
                lv145: R.Tensor((1, 512, 64, 64), dtype="float16") = R.astype(lv2, dtype="float16")
                lv146: R.Tensor((1, 512, 1, 1), dtype="float16") = R.reshape(lv1, (1, 512, 1, 1))
                lv147: R.Tensor((1, 512, 64, 64), dtype="float16") = R.add(lv145, lv146)
                # Final cast restores the function's declared fp32 output.
                gv: R.Tensor((1, 512, 64, 64), dtype="float32") = R.astype(lv147, dtype="float32")
                R.output(gv)
            return gv

    # Bind identical random fp16 weights into both modules so the pass sees
    # constants rather than free parameters.
    binding = {
        "w0": np.random.uniform(size=(512, 4, 3, 3)).astype("float16"),
        "w1": np.random.uniform(size=(512,)).astype("float16"),
        "w2": np.random.uniform(size=(4, 4, 1, 1)).astype("float16"),
        "w3": np.random.uniform(size=(4,)).astype("float16"),
    }
    binding = {k: tvm.nd.array(v) for k, v in binding.items()}
    Input = relax.transform.BindParams("main", binding)(Input)
    Expected = relax.transform.BindParams("main", binding)(Expected)
    _assert_test(Input, Expected)
+
+
# Allow running this test file directly; tvm.testing.main() dispatches to
# pytest for all test_* functions in the module.
if __name__ == "__main__":
    tvm.testing.main()