This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new fec44f605e [Unity][Transform] Automatic Layout Conversion (#14257)
fec44f605e is described below
commit fec44f605edae8d96b03a571cf1a786843cc6c82
Author: Bohan Hou <[email protected]>
AuthorDate: Sun Mar 19 10:55:57 2023 -0400
[Unity][Transform] Automatic Layout Conversion (#14257)
This PR adds a new pass ConvertLayout that converts the layout of conv2d
operators (NCHW->NHWC) and tries to propagate this conversion when appropriate.
See
https://github.com/spectrometerHBH/tvm/blob/convert-layout/src/relax/transform/convert_layout.cc#L39
for details on how this pass works.
It works on the op level, which is in parallel with the ongoing Layout
Transformation effort that works on TIR level in the community.
---
include/tvm/relax/nested_msg.h | 44 +
include/tvm/relax/transform.h | 9 +
python/tvm/relax/transform/transform.py | 17 +
src/relax/op/image/resize.cc | 16 +-
src/relax/op/nn/convolution.cc | 45 +-
src/relax/op/nn/nn.cc | 101 +-
src/relax/op/nn/pooling.cc | 43 +-
src/relax/op/op_common.cc | 8 +
src/relax/op/op_common.h | 22 +-
src/relax/op/tensor/binary.cc | 22 +
src/relax/op/tensor/binary.h | 3 +-
src/relax/op/tensor/datatype.cc | 3 +-
src/relax/op/tensor/index.cc | 23 +-
src/relax/op/tensor/manipulate.cc | 189 ++-
src/relax/op/tensor/statistical.cc | 52 +
src/relax/op/tensor/statistical.h | 27 +-
src/relax/op/tensor/ternary.cc | 18 +-
src/relax/transform/convert_layout.cc | 309 +++++
src/relax/transform/infer_layout_utils.cc | 126 ++
src/relax/transform/infer_layout_utils.h | 244 ++++
src/relax/transform/utils.cc | 32 +
src/relax/transform/utils.h | 42 +
.../python/relax/test_transform_convert_layout.py | 1352 ++++++++++++++++++++
23 files changed, 2711 insertions(+), 36 deletions(-)
diff --git a/include/tvm/relax/nested_msg.h b/include/tvm/relax/nested_msg.h
index 93fc9a36c5..0564c26687 100644
--- a/include/tvm/relax/nested_msg.h
+++ b/include/tvm/relax/nested_msg.h
@@ -531,6 +531,50 @@ Expr TransformTupleLeaf(Expr expr,
std::array<NestedMsg<T>, N> msgs, FType ftran
}
}
+/*!
+ * \brief Recursively transform the tuple structure in sinfo and msgs along
with it.
+ *
+ * This function will call ftransleaf for each leaf sinfo in sinfo.
+ * This function will throw an error if the nesting structure in msg does not
+ * match the tuple nesting structure in sinfo.
+ *
+ * \param sinfo The input sinfo to be transformed.
+ * \param msgs The input messages to guide the transformation.
+ * \param ftransleaf with signature ftransleaf(StructInfo,
Array<NestedMsg<T>>)->StructInfo
+ * \tparam T the content type of nested msg
+ * \tparam N the number of messages
+ * \tparam FType The visit function type.
+ */
+template <typename T, std::size_t N, typename FType>
+StructInfo TransformTupleLeaf(StructInfo sinfo, std::array<NestedMsg<T>, N>
msgs,
+ FType ftransleaf) {
+ if (const auto* tuple = sinfo.as<TupleStructInfoNode>()) {
+ std::array<Array<NestedMsg<T>>, N> msg_arrays;
+ for (size_t i = 0; i < N; ++i) {
+ ICHECK(msgs[i].IsNested()) << "Expected nested to match tuple";
+ msg_arrays[i] = msgs[i].NestedArray();
+ }
+ bool same = true;
+ Array<StructInfo> fields;
+ fields.reserve(tuple->fields.size());
+ for (size_t i = 0; i < tuple->fields.size(); ++i) {
+ StructInfo field = tuple->fields[i];
+ std::array<NestedMsg<T>, N> sub_msgs;
+ for (size_t j = 0; j < N; ++j) {
+ sub_msgs[j] = msg_arrays[j][i];
+ }
+ fields.push_back(TransformTupleLeaf(field, std::move(sub_msgs),
ftransleaf));
+ same &= (fields.back().same_as(field));
+ }
+ return same ? sinfo : TupleStructInfo(fields);
+ } else {
+ for (const auto& msg : msgs) {
+ ICHECK(msg.IsLeaf()) << "Expected leaf to match non-tuple";
+ }
+ return ftransleaf(sinfo, msgs);
+ }
+}
+
} // namespace relax
} // namespace tvm
#endif // TVM_RELAX_NESTED_MSG_H_
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index e0fe226e83..ead8b0c31e 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -369,6 +369,7 @@ TVM_DLL Pass RunCodegen(Optional<Map<String, Map<String,
ObjectRef>>> target_opt
* \return The Pass.
*/
TVM_DLL Pass SimplifyNormInference();
+
/*!
* \brief Returns a pass which replaces PrimFuncs which have matching
kOperatorName attribute in \p
* op_impl_map, with replacement PrimFunc that could possibly have different
layouts on i/o
@@ -383,6 +384,14 @@ TVM_DLL Pass SimplifyNormInference();
*/
TVM_DLL Pass AlterOpImpl(const Map<String, tir::PrimFunc>& op_impl_map,
const Map<String, Array<tir::IndexMap>>&
op_buffer_transforms);
+
+/*!
+ * \brief Layout conversion pass.
+ * \param desired_layouts The desired layouts for some operators.
+ * \return The Pass.
+ */
+TVM_DLL Pass ConvertLayout(Map<String, Array<String>> desired_layouts);
+
} // namespace transform
} // namespace relax
} // namespace tvm
diff --git a/python/tvm/relax/transform/transform.py
b/python/tvm/relax/transform/transform.py
index 0df29dc093..c10d0130c1 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -640,6 +640,23 @@ def AlterOpImpl(
return _ffi_api.AlterOpImpl(op_impl_map, op_buffer_transforms) # type:
ignore
+def ConvertLayout(desired_layouts: Dict[str, List[str]]) ->
tvm.ir.transform.Pass:
+ """Automatic layout conversion pass.
+ Parameters
+ ----------
+ desired_layouts : Dict[str, List[str]]
+ The desired layout of conv2d ops is a map from the name of the op to
the desired layout
+ of the input feature map, weight and output. For example, if we want
to convert the
+ layout of conv2d from NCHW to NHWC, we can set the desired layout of
conv2d to be
+ {"conv2d": ["NHWC", "OHWI"]}.
+ Returns
+ -------
+ ret : tvm.transform.Pass
+ The registered pass for layout conversion.
+ """
+ return _ffi_api.ConvertLayout(desired_layouts) # type: ignore
+
+
def _wrap_class_function_pass(pass_cls, pass_info):
"""Wrap a python class as function pass."""
diff --git a/src/relax/op/image/resize.cc b/src/relax/op/image/resize.cc
index 2711b3cc45..de6eec6236 100644
--- a/src/relax/op/image/resize.cc
+++ b/src/relax/op/image/resize.cc
@@ -102,12 +102,26 @@ StructInfo InferStructInfoResize2D(const Call& call,
const BlockBuilder& ctx) {
return TensorStructInfo(ShapeExpr(out_shape), out_dtype);
}
+InferLayoutOutput InferLayoutResize2d(const Call& call,
+ const Map<String, Array<String>>&
desired_layouts,
+ const VarLayoutMap& var_layout_map) {
+ ICHECK(NoDesiredLayout(call, desired_layouts));
+ const auto* attrs = call->attrs.as<Resize2DAttrs>();
+ ICHECK(attrs) << "Invalid Call";
+
+ LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);
+ ObjectPtr<Resize2DAttrs> new_attrs = make_object<Resize2DAttrs>(*attrs);
+ new_attrs->layout = TransposeLike(attrs->layout, InitialLayout(4),
layout->layout).name();
+ return InferLayoutOutput({layout, InitialNLayout(call->args[1])}, {layout},
Attrs(new_attrs));
+}
+
TVM_REGISTER_OP("relax.image.resize2d")
.set_attrs_type<Resize2DAttrs>()
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("size", "Shape", "The output image shape.")
- .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoResize2D);
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoResize2D)
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutResize2d);
} // namespace relax
} // namespace tvm
diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc
index 8dc3c9696f..f356876620 100644
--- a/src/relax/op/nn/convolution.cc
+++ b/src/relax/op/nn/convolution.cc
@@ -135,12 +135,55 @@ StructInfo InferStructInfoConv2d(const Call& call, const
BlockBuilder& ctx) {
return TensorStructInfo(ShapeExpr(out_shape), out_dtype);
}
+InferLayoutOutput InferLayoutConv2d(const Call& call,
+ const Map<String, Array<String>>&
desired_layouts,
+ const VarLayoutMap& var_layout_map) {
+ const auto& it = desired_layouts.find("relax.nn.conv2d");
+ const auto* attrs = call->attrs.as<Conv2DAttrs>();
+ ICHECK(attrs) << "Invalid Call";
+
+ LayoutDecision data_layout, weight_layout, output_layout;
+ ObjectPtr<Conv2DAttrs> new_attrs = make_object<Conv2DAttrs>(*attrs);
+
+ if (it != desired_layouts.end()) {
+ // We have a desired layout for conv2d.
+ Layout desired_data_layout = (*it).second[0];
+ Layout desired_weight_layout = (*it).second[1];
+ Layout desired_output_layout = (*it).second.size() == 3 ? (*it).second[2]
: (*it).second[0];
+ ICHECK_EQ(desired_data_layout.ndim(), desired_data_layout.ndim_primal())
<< "Axis swap only";
+ ICHECK_EQ(desired_weight_layout.ndim(),
desired_weight_layout.ndim_primal())
+ << "Axis swap only";
+ ICHECK_EQ(desired_output_layout.ndim(),
desired_output_layout.ndim_primal())
+ << "Axis swap only";
+ data_layout = TransposeLike(InitialLayout(4), attrs->data_layout,
desired_data_layout);
+ weight_layout = TransposeLike(InitialLayout(4), attrs->kernel_layout,
desired_weight_layout);
+ output_layout = TransposeLike(InitialLayout(4), attrs->out_layout,
desired_output_layout);
+ new_attrs->data_layout = (*it).second[0];
+ new_attrs->kernel_layout = (*it).second[1];
+ new_attrs->out_layout = (*it).second.size() == 3 ? (*it).second[2] :
(*it).second[0];
+ } else {
+ // We don't have a desired layout for conv2d.
+ // We can just propagate the layout from the input.
+ data_layout = GetLayoutDecision(var_layout_map, call->args[0]);
+ weight_layout = GetLayoutDecision(var_layout_map, call->args[1]);
+ output_layout = data_layout;
+ new_attrs->data_layout =
+ TransposeLike(attrs->data_layout, InitialLayout(4),
data_layout->layout).name();
+ new_attrs->kernel_layout =
+ TransposeLike(attrs->kernel_layout, InitialLayout(4),
weight_layout->layout).name();
+ new_attrs->out_layout =
+ TransposeLike(attrs->out_layout, InitialLayout(4),
output_layout->layout).name();
+ }
+ return InferLayoutOutput({data_layout, weight_layout}, {output_layout},
Attrs(new_attrs));
+}
+
TVM_REGISTER_OP("relax.nn.conv2d")
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.")
.set_attrs_type<Conv2DAttrs>()
- .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoConv2d);
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoConv2d)
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutConv2d);
/* relax.nn.conv2d_transpose */
TVM_REGISTER_NODE_TYPE(Conv2DTransposeAttrs);
diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc
index 430d2268ce..6bce51ca50 100644
--- a/src/relax/op/nn/nn.cc
+++ b/src/relax/op/nn/nn.cc
@@ -62,11 +62,25 @@ StructInfo InferStructInfoSoftmax(const Call& call, const
BlockBuilder& ctx) {
return data_sinfo;
}
+InferLayoutOutput InferLayoutSoftmax(const Call& call,
+ const Map<String, Array<String>>&
desired_layouts,
+ const VarLayoutMap& var_layout_map) {
+ ICHECK(NoDesiredLayout(call, desired_layouts));
+ const auto* attrs = call->attrs.as<SoftmaxAttrs>();
+ ICHECK(attrs) << "Invalid Call";
+
+ LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);
+ ObjectPtr<SoftmaxAttrs> new_attrs = make_object<SoftmaxAttrs>(*attrs);
+ new_attrs->axis = FindAxis(layout->layout, attrs->axis);
+ return InferLayoutOutput({layout}, {layout}, Attrs(new_attrs));
+}
+
TVM_REGISTER_OP("relax.nn.softmax")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_attrs_type<SoftmaxAttrs>()
- .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoSoftmax);
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoSoftmax)
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutSoftmax);
/* relax.nn.log_softmax */
Expr log_softmax(Expr data, int axis) {
@@ -188,6 +202,28 @@ StructInfo InferStructInfoBatchNorm(const Call& call,
const BlockBuilder& ctx) {
}
}
+InferLayoutOutput InferLayoutBatchNorm(const Call& call,
+ const Map<String, Array<String>>&
desired_layouts,
+ const VarLayoutMap& var_layout_map) {
+ ICHECK(NoDesiredLayout(call, desired_layouts));
+ std::vector<NLayout> initial_layouts;
+ for (size_t i = 0; i < 5; ++i) {
+ const auto* tensor_sinfo =
GetStructInfoAs<TensorStructInfoNode>(call->args[i]);
+ ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
+ ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim";
+ initial_layouts.push_back(InitialLayoutDecision(tensor_sinfo->ndim));
+ }
+ const auto* attrs = call->attrs.as<BatchNormAttrs>();
+ ICHECK(attrs) << "Invalid Call";
+
+ LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);
+ ObjectPtr<BatchNormAttrs> new_attrs = make_object<BatchNormAttrs>(*attrs);
+ new_attrs->axis = FindAxis(layout->layout, attrs->axis);
+ return InferLayoutOutput(
+ {layout, initial_layouts[1], initial_layouts[2], initial_layouts[3],
initial_layouts[4]},
+ {{layout, initial_layouts[3], initial_layouts[4]}}, Attrs(new_attrs));
+}
+
TVM_REGISTER_OP("relax.nn.batch_norm")
.set_attrs_type<BatchNormAttrs>()
.set_num_inputs(5)
@@ -196,7 +232,8 @@ TVM_REGISTER_OP("relax.nn.batch_norm")
.add_argument("beta", "Tensor", "The beta offset factor.")
.add_argument("moving_mean", "Tensor", "Running mean of input.")
.add_argument("moving_var", "Tensor", "Running variance of input.")
- .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoBatchNorm);
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoBatchNorm)
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutBatchNorm);
/* relax.nn.layer_norm */
TVM_REGISTER_NODE_TYPE(LayerNormAttrs);
@@ -225,13 +262,39 @@ StructInfo InferStructInfoLayerNorm(const Call& call,
const BlockBuilder& ctx) {
: input_sinfo[0];
}
+InferLayoutOutput InferLayoutLayerNorm(const Call& call,
+ const Map<String, Array<String>>&
desired_layouts,
+ const VarLayoutMap& var_layout_map) {
+ ICHECK(NoDesiredLayout(call, desired_layouts));
+ std::vector<NLayout> initial_layouts;
+ for (size_t i = 0; i < 3; ++i) {
+ const auto* tensor_sinfo =
GetStructInfoAs<TensorStructInfoNode>(call->args[i]);
+ ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
+ ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim";
+ initial_layouts.push_back(InitialLayoutDecision(tensor_sinfo->ndim));
+ }
+ const auto* attrs = call->attrs.as<LayerNormAttrs>();
+ ICHECK(attrs) << "Invalid Call";
+
+ LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);
+ ObjectPtr<LayerNormAttrs> new_attrs = make_object<LayerNormAttrs>(*attrs);
+ std::vector<Integer> new_axis;
+ for (const auto& axis : attrs->axes) {
+ new_axis.push_back(FindAxis(layout->layout, axis->value));
+ }
+ new_attrs->axes = std::move(new_axis);
+ return InferLayoutOutput({layout, initial_layouts[1], initial_layouts[2]},
{layout},
+ Attrs(new_attrs));
+}
+
TVM_REGISTER_OP("relax.nn.layer_norm")
.set_attrs_type<LayerNormAttrs>()
.set_num_inputs(3)
.add_argument("data", "Tensor", "Input to which batch_norm will be
applied.")
.add_argument("gamma", "Tensor", "The gamma scale factor.")
.add_argument("beta", "Tensor", "The beta offset factor.")
- .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoLayerNorm);
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoLayerNorm)
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutLayerNorm);
/* relax.nn.group_norm */
TVM_REGISTER_NODE_TYPE(GroupNormAttrs);
@@ -308,13 +371,40 @@ StructInfo InferStructInfoGroupNorm(const Call& call,
const BlockBuilder& ctx) {
return data_sinfo;
}
+InferLayoutOutput InferLayoutGroupNorm(const Call& call,
+ const Map<String, Array<String>>&
desired_layouts,
+ const VarLayoutMap& var_layout_map) {
+ ICHECK(NoDesiredLayout(call, desired_layouts));
+ std::vector<NLayout> initial_layouts;
+ for (size_t i = 0; i < 3; ++i) {
+ const auto* tensor_sinfo =
GetStructInfoAs<TensorStructInfoNode>(call->args[i]);
+ ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
+ ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim";
+ initial_layouts.push_back(InitialLayoutDecision(tensor_sinfo->ndim));
+ }
+ const auto* attrs = call->attrs.as<GroupNormAttrs>();
+ ICHECK(attrs) << "Invalid Call";
+
+ LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);
+ ObjectPtr<GroupNormAttrs> new_attrs = make_object<GroupNormAttrs>(*attrs);
+ std::vector<Integer> new_axes;
+ for (const auto& axis : attrs->axes) {
+ new_axes.push_back(FindAxis(layout->layout, axis->value));
+ }
+ new_attrs->axes = std::move(new_axes);
+ new_attrs->channel_axis = FindAxis(layout->layout, attrs->channel_axis);
+ return InferLayoutOutput({layout, initial_layouts[1], initial_layouts[2]},
{layout},
+ Attrs(new_attrs));
+}
+
TVM_REGISTER_OP("relax.nn.group_norm")
.set_attrs_type<GroupNormAttrs>()
.set_num_inputs(3)
.add_argument("data", "Tensor", "Input to which batch_norm will be
applied.")
.add_argument("gamma", "Tensor", "The gamma scale factor.")
.add_argument("beta", "Tensor", "The beta offset factor.")
- .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoGroupNorm);
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoGroupNorm)
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutGroupNorm);
/* relax.nn.dropout */
TVM_REGISTER_NODE_TYPE(DropoutAttrs);
@@ -338,7 +428,8 @@ TVM_REGISTER_OP("relax.nn.dropout")
.set_attrs_type<DropoutAttrs>()
.set_num_inputs(1)
.add_argument("data", "Tensor", "Input to which dropout will be applied.")
- .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoDropout);
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoDropout)
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutUnaryEwise);
/* relax.nn.cross_entropy_with_logits */
StructInfo InferStructInfoCrossEntropy(const Call& call, const BlockBuilder&
ctx) {
diff --git a/src/relax/op/nn/pooling.cc b/src/relax/op/nn/pooling.cc
index 61001ce678..be0a794dee 100644
--- a/src/relax/op/nn/pooling.cc
+++ b/src/relax/op/nn/pooling.cc
@@ -117,11 +117,29 @@ StructInfo InferStructInfoPool2D(const Call& call, const
BlockBuilder& ctx) {
return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype);
}
+InferLayoutOutput InferLayoutPool2d(const Call& call,
+ const Map<String, Array<String>>&
desired_layouts,
+ const VarLayoutMap& var_layout_map) {
+ ICHECK(NoDesiredLayout(call, desired_layouts));
+ const auto* tensor_sinfo = GetStructInfoAs<TensorStructInfoNode>(call);
+ ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
+ ICHECK_EQ(tensor_sinfo->ndim, 4) << "Unsupported initial layout";
+ const auto* attrs = call->attrs.as<Pool2DAttrs>();
+ ICHECK(attrs) << "Invalid Call";
+
+ LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);
+ ObjectPtr<Pool2DAttrs> new_attrs = make_object<Pool2DAttrs>(*attrs);
+ new_attrs->layout = TransposeLike(attrs->layout, InitialLayout(4),
layout->layout).name();
+ new_attrs->out_layout = TransposeLike(attrs->out_layout, InitialLayout(4),
layout->layout).name();
+ return InferLayoutOutput({layout}, {layout}, Attrs(new_attrs));
+}
+
TVM_REGISTER_OP("relax.nn.max_pool2d")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor")
.set_attrs_type<Pool2DAttrs>()
- .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoPool2D);
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoPool2D)
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutPool2d);
Expr avg_pool2d(Expr data, Array<IntImm> pool_size, Array<IntImm> strides,
Array<IntImm> padding,
Array<IntImm> dilation, bool ceil_mode, String layout,
@@ -136,7 +154,8 @@ TVM_REGISTER_OP("relax.nn.avg_pool2d")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor")
.set_attrs_type<Pool2DAttrs>()
- .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoPool2D);
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoPool2D)
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutPool2d);
/* relax.nn.adaptive_avg_pool2d */
TVM_REGISTER_NODE_TYPE(AdaptivePool2DAttrs);
@@ -196,11 +215,29 @@ StructInfo InferStructInfoAdaptiveAvgPool2D(const Call&
call, const BlockBuilder
return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype);
}
+InferLayoutOutput InferLayoutAdaptiveAvgPool2D(const Call& call,
+ const Map<String,
Array<String>>& desired_layouts,
+ const VarLayoutMap&
var_layout_map) {
+ ICHECK(NoDesiredLayout(call, desired_layouts));
+ const auto* tensor_sinfo = GetStructInfoAs<TensorStructInfoNode>(call);
+ ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
+ ICHECK_EQ(tensor_sinfo->ndim, 4) << "Unsupported initial layout";
+ const auto* attrs = call->attrs.as<AdaptivePool2DAttrs>();
+ ICHECK(attrs) << "Invalid Call";
+
+ LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);
+ ObjectPtr<AdaptivePool2DAttrs> new_attrs =
make_object<AdaptivePool2DAttrs>(*attrs);
+ new_attrs->layout = TransposeLike(attrs->layout, InitialLayout(4),
layout->layout).name();
+ new_attrs->out_layout = TransposeLike(attrs->out_layout, InitialLayout(4),
layout->layout).name();
+ return InferLayoutOutput({layout}, {layout}, Attrs(new_attrs));
+}
+
TVM_REGISTER_OP("relax.nn.adaptive_avg_pool2d")
.set_attrs_type<AdaptivePool2DAttrs>()
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor")
- .set_attr<FInferStructInfo>("FInferStructInfo",
InferStructInfoAdaptiveAvgPool2D);
+ .set_attr<FInferStructInfo>("FInferStructInfo",
InferStructInfoAdaptiveAvgPool2D)
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout",
InferLayoutAdaptiveAvgPool2D);
} // namespace relax
} // namespace tvm
diff --git a/src/relax/op/op_common.cc b/src/relax/op/op_common.cc
index c82c325d9b..0421c957ca 100644
--- a/src/relax/op/op_common.cc
+++ b/src/relax/op/op_common.cc
@@ -117,5 +117,13 @@ std::vector<int> NormalizeAxes(const Call& call, const
BlockBuilder& ctx, int nd
return axes_non_neg;
}
+InferLayoutOutput InferLayoutUnaryEwise(const Call& call,
+ const Map<String, Array<String>>&
desired_layouts,
+ const VarLayoutMap& var_layout_map) {
+ ICHECK(NoDesiredLayout(call, desired_layouts));
+ LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);
+ return InferLayoutOutput({layout}, {layout}, Attrs(call->attrs));
+}
+
} // namespace relax
} // namespace tvm
diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h
index 29e02946c6..ece4c4a321 100644
--- a/src/relax/op/op_common.h
+++ b/src/relax/op/op_common.h
@@ -34,6 +34,8 @@
#include <utility>
#include <vector>
+#include "../transform/infer_layout_utils.h"
+
namespace tvm {
namespace relax {
@@ -68,10 +70,11 @@ inline TensorStructInfo GetUnaryInputTensorStructInfo(const
Call& call, const Bl
* \param OpRegName The name of operator to register. The name passed in will
* be prepended with a prefix "relax." as the identifier string in the
operator registry.
*/
-#define RELAX_REGISTER_UNARY_OP(OpRegName) \
- TVM_REGISTER_OP("relax." OpRegName) \
- .set_num_inputs(1) \
- .add_argument("x", "Tensor", "The input tensor.")
+#define RELAX_REGISTER_UNARY_OP(OpRegName) \
+ TVM_REGISTER_OP("relax." OpRegName) \
+ .set_num_inputs(1) \
+ .add_argument("x", "Tensor", "The input tensor.") \
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutUnaryEwise)
/*!
* \brief Quick helper macro to expose a make-function to construct the
operator.
@@ -151,6 +154,17 @@ StructInfo InferStructInfoUnaryArith(const Call& call,
const BlockBuilder& ctx)
call, ctx, [](const TensorStructInfo& input_sinfo) { return
input_sinfo->dtype; });
}
+/*!
+ * \brief Layout infer util for unary elementwise ops. It will simply take the
layout of the input.
+ * \param call The context Call to the operator.
+ * \param desired_layouts The desired layouts of certain ops.
+ * \param var_layout_map The layout of vars.
+ * \return The inferred layout result.
+ */
+InferLayoutOutput InferLayoutUnaryEwise(const Call& call,
+ const Map<String, Array<String>>&
desired_layouts,
+ const VarLayoutMap& var_layout_map);
+
/*!
* \brief Infer the output datatype for binary arithmetic operators.
* \param call The context Call to the operator.
diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc
index 7e8480ee16..30cd748308 100644
--- a/src/relax/op/tensor/binary.cc
+++ b/src/relax/op/tensor/binary.cc
@@ -78,6 +78,28 @@ StructInfo InferStructInfoBroadcastCMP(const Call& call,
const BlockBuilder& ctx
const TensorStructInfo& x2_sinfo) { return DataType::Bool(); });
}
+InferLayoutOutput InferLayoutBinaryEwise(const Call& call,
+ const Map<String, Array<String>>&
desired_layouts,
+ const VarLayoutMap& var_layout_map) {
+ ICHECK(NoDesiredLayout(call, desired_layouts));
+ LayoutDecision layout1 = GetLayoutDecision(var_layout_map, call->args[0]);
+ LayoutDecision layout2 = GetLayoutDecision(var_layout_map, call->args[1]);
+
+ auto* x1_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+ auto* x2_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[1]);
+
+ ICHECK(!x1_sinfo->IsUnknownNdim() && !x2_sinfo->IsUnknownNdim())
+ << "Unknown dim tensors should not be handled by this function";
+
+ if (x1_sinfo->ndim <= x2_sinfo->ndim) {
+ LayoutDecision out_layout = FollowDecision(layout1, x2_sinfo->ndim);
+ return InferLayoutOutput({layout1, out_layout}, {out_layout},
Attrs(call->attrs));
+ } else {
+ LayoutDecision out_layout = FollowDecision(layout2, x1_sinfo->ndim);
+ return InferLayoutOutput({out_layout, layout2}, {out_layout},
Attrs(call->attrs));
+ }
+}
+
/***************** Arithmetic operators *****************/
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(add);
diff --git a/src/relax/op/tensor/binary.h b/src/relax/op/tensor/binary.h
index 0a48e727e6..197110c000 100644
--- a/src/relax/op/tensor/binary.h
+++ b/src/relax/op/tensor/binary.h
@@ -46,7 +46,8 @@ namespace relax {
TVM_REGISTER_OP("relax." #OpName) \
.set_num_inputs(2) \
.add_argument("x1", "Tensor", "The first input tensor.") \
- .add_argument("x2", "Tensor", "The second input tensor.")
+ .add_argument("x2", "Tensor", "The second input tensor.") \
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutBinaryEwise)
#define RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(OpName) \
RELAX_REGISTER_BINARY_OP_AND_IMPL(OpName).set_attr<FInferStructInfo>( \
diff --git a/src/relax/op/tensor/datatype.cc b/src/relax/op/tensor/datatype.cc
index 0c647aa866..349a54ee4d 100644
--- a/src/relax/op/tensor/datatype.cc
+++ b/src/relax/op/tensor/datatype.cc
@@ -54,7 +54,8 @@ TVM_REGISTER_OP("relax.astype")
.set_attrs_type<AstypeAttrs>()
.set_num_inputs(1)
.add_argument("x", "Tensor", "The input tensor")
- .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoAstype);
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoAstype)
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutUnaryEwise);
} // namespace relax
} // namespace tvm
diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc
index 246abef908..29f668ccf3 100644
--- a/src/relax/op/tensor/index.cc
+++ b/src/relax/op/tensor/index.cc
@@ -185,11 +185,32 @@ StructInfo InferStructInfoStridedSlice(const Call& call,
const BlockBuilder& ctx
return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype);
}
+InferLayoutOutput InferLayoutStridedSlice(const Call& call,
+ const Map<String, Array<String>>&
desired_layouts,
+ const VarLayoutMap& var_layout_map) {
+ ICHECK(NoDesiredLayout(call, desired_layouts));
+
+ const auto* attrs = call->attrs.as<StridedSliceAttrs>();
+ ICHECK(attrs != nullptr) << "Invalid Call";
+ const auto* tensor_sinfo =
GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+ ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
+ ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim";
+ LayoutDecision existing_layout = GetLayoutDecision(var_layout_map,
call->args[0]);
+ std::vector<Integer> new_axes;
+ for (const auto& axis : attrs->axes) {
+ new_axes.push_back(FindAxis(existing_layout->layout, axis->value));
+ }
+ ObjectPtr<StridedSliceAttrs> new_attrs =
make_object<StridedSliceAttrs>(*attrs);
+ new_attrs->axes = std::move(new_axes);
+ return InferLayoutOutput({existing_layout}, {existing_layout},
Attrs(new_attrs));
+}
+
TVM_REGISTER_OP("relax.strided_slice")
.set_attrs_type<StridedSliceAttrs>()
.set_num_inputs(1)
.add_argument("x", "Tensor", "The source tensor to be sliced.")
- .set_attr<FInferStructInfo>("FInferStructInfo",
InferStructInfoStridedSlice);
+ .set_attr<FInferStructInfo>("FInferStructInfo",
InferStructInfoStridedSlice)
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutStridedSlice);
} // namespace relax
} // namespace tvm
diff --git a/src/relax/op/tensor/manipulate.cc
b/src/relax/op/tensor/manipulate.cc
index c7bf051302..49f745608f 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -26,6 +26,7 @@
#include <algorithm>
#include <numeric>
+#include <string>
#include <utility>
#include <vector>
@@ -272,11 +273,35 @@ StructInfo InferStructInfoConcat(const Call& call, const
BlockBuilder& ctx) {
}
}
+InferLayoutOutput InferLayoutConcat(const Call& call,
+ const Map<String, Array<String>>&
desired_layouts,
+ const VarLayoutMap& var_layout_map) {
+ ICHECK(NoDesiredLayout(call, desired_layouts));
+
+ const auto* attrs = call->attrs.as<ConcatAttrs>();
+ ICHECK(attrs != nullptr) << "Invalid Call";
+ NLayout nlayout = GetNLayout(var_layout_map, call->args[0]);
+ ICHECK(nlayout.IsNested());
+ ICHECK(nlayout.NestedArray()[0].IsLeaf());
+
+ int n_tensor = nlayout.NestedArray().size();
+ LayoutDecision layout = nlayout.NestedArray()[0].LeafValue();
+ Array<NLayout> input_layouts, output_layouts;
+ for (int i = 0; i < n_tensor; ++i) {
+ input_layouts.push_back(layout);
+ }
+ output_layouts.push_back(layout);
+ ObjectPtr<ConcatAttrs> new_attrs = make_object<ConcatAttrs>(*attrs);
+ new_attrs->axis = Integer(FindAxis(layout->layout,
attrs->axis.value_or(0)->value));
+ return InferLayoutOutput({NLayout(input_layouts)}, output_layouts,
Attrs(new_attrs));
+}
+
TVM_REGISTER_OP("relax.concat")
.set_attrs_type<ConcatAttrs>()
.set_num_inputs(1)
.add_argument("tensors", "Tuple of Tensors", "The input list of tensors.")
- .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoConcat);
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoConcat)
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutConcat);
/* relax.expand_dims */
TVM_REGISTER_NODE_TYPE(ExpandDimsAttrs);
@@ -330,11 +355,49 @@ StructInfo InferStructInfoExpandDims(const Call& call,
const BlockBuilder& ctx)
return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype);
}
+InferLayoutOutput InferLayoutExpandDims(const Call& call,
+ const Map<String, Array<String>>&
desired_layouts,
+ const VarLayoutMap& var_layout_map) {
+ ICHECK(NoDesiredLayout(call, desired_layouts));
+ const auto* attrs = call->attrs.as<ExpandDimsAttrs>();
+ ICHECK(attrs != nullptr) << "Invalid Call";
+ const auto* tensor_sinfo =
GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+ ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
+ ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now";
+
+ LayoutDecision existing_layout = GetLayoutDecision(var_layout_map,
call->args[0]);
+ int ndim = tensor_sinfo->ndim;
+ int n_new_dim = attrs->axis.size();
+ int output_ndim = ndim + n_new_dim;
+ std::vector<bool> is_new_dim(output_ndim, false);
+ for (const auto& axis : attrs->axis) {
+ is_new_dim[(axis->value + output_ndim) % output_ndim] = true;
+ }
+ std::string new_layout;
+ for (int i = 0; i < output_ndim; ++i) {
+ if (!is_new_dim[i]) {
+ new_layout.push_back('A' + i);
+ }
+ }
+ new_layout = TransposeStrLike(new_layout, InitialLayout(ndim),
existing_layout->layout);
+ std::string output_layout;
+ for (int i = 0, j = 0; i < output_ndim; ++i) {
+ if (is_new_dim[i]) {
+ output_layout.push_back('A' + i);
+ } else {
+ output_layout.push_back(new_layout.at(j++));
+ }
+ }
+ return InferLayoutOutput({existing_layout},
{LayoutDecision(Layout(output_layout))},
+ Attrs(call->attrs));
+}
+
TVM_REGISTER_OP("relax.expand_dims")
.set_num_inputs(1)
.set_attrs_type<ExpandDimsAttrs>()
.add_argument("x", "Tensor", "The input tensor.")
- .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoExpandDims);
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoExpandDims)
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutExpandDims);
// Helper function for flatten and reshape.
PrimExpr ComputeShapeProduct(const Array<PrimExpr>& shape_values) {
@@ -505,11 +568,49 @@ StructInfo InferStructInfoPermuteDims(const Call& call,
const BlockBuilder& ctx)
return TensorStructInfo(ShapeExpr(new_shape), data_sinfo->dtype);
}
+InferLayoutOutput InferLayoutPermuteDims(const Call& call,
+ const Map<String, Array<String>>&
desired_layouts,
+ const VarLayoutMap& var_layout_map) {
+ ICHECK(NoDesiredLayout(call, desired_layouts));
+
+ const auto* attrs = call->attrs.as<PermuteDimsAttrs>();
+ ICHECK(attrs != nullptr) << "Invalid Call";
+ const auto* tensor_sinfo =
GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+ ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
+ ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now";
+ int ndim = tensor_sinfo->ndim;
+
+ LayoutDecision existing_layout = GetLayoutDecision(var_layout_map,
call->args[0]);
+ Array<Integer> order;
+ if (attrs->axes.defined()) {
+ order = attrs->axes.value();
+ } else {
+ order.reserve(ndim);
+ for (int i = 0; i < ndim; ++i) {
+ order.push_back(Integer(ndim - i - 1));
+ }
+ }
+ std::string order_str;
+ for (const auto& axis : order) {
+ order_str.push_back(axis->value + 'A');
+ }
+ String new_axes =
+ TransposeStrLike(InitialLayout(ndim).name(), existing_layout->layout,
order_str);
+ Array<Integer> new_order;
+ for (size_t i = 0; i < new_axes.size(); ++i) {
+ new_order.push_back(Integer(new_axes.at(i) - 'A'));
+ }
+ ObjectPtr<PermuteDimsAttrs> new_attrs =
make_object<PermuteDimsAttrs>(*attrs);
+ new_attrs->axes = new_order;
+ return InferLayoutOutput({existing_layout}, {InitialLayoutDecision(ndim)},
Attrs(new_attrs));
+}
+
TVM_REGISTER_OP("relax.permute_dims")
.set_attrs_type<PermuteDimsAttrs>()
.set_num_inputs(1)
.add_argument("x", "Tensor", "The input tensor.")
- .set_attr<FInferStructInfo>("FInferStructInfo",
InferStructInfoPermuteDims);
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoPermuteDims)
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutPermuteDims);
/* relax.reshape */
Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) {
@@ -758,11 +859,33 @@ StructInfo InferStructInfoSplit(const Call& call, const
BlockBuilder& ctx) {
throw;
}
+InferLayoutOutput InferLayoutSplit(const Call& call,
+ const Map<String, Array<String>>&
desired_layouts,
+ const VarLayoutMap& var_layout_map) {
+ ICHECK(NoDesiredLayout(call, desired_layouts));
+
+ const auto* attrs = call->attrs.as<SplitAttrs>();
+ ICHECK(attrs != nullptr) << "Invalid Call";
+ const auto* tensor_sinfo =
GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+ ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
+ ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim";
+
+ LayoutDecision existing_layout = GetLayoutDecision(var_layout_map,
call->args[0]);
+ ObjectPtr<SplitAttrs> new_attrs = make_object<SplitAttrs>(*attrs);
+ new_attrs->axis = FindAxis(existing_layout->layout, attrs->axis);
+ StructInfo out_sinfo = InferStructInfoSplit(call,
BlockBuilder::Create(IRModule()));
+ const auto* out_tuple = out_sinfo.as<TupleStructInfoNode>();
+ ICHECK(out_tuple != nullptr) << "Invalid Call";
+ NLayout tuple_layouts(Array<NLayout>(out_tuple->fields.size(),
existing_layout));
+ return InferLayoutOutput({existing_layout}, {tuple_layouts},
Attrs(new_attrs));
+}
+
TVM_REGISTER_OP("relax.split")
.set_attrs_type<SplitAttrs>()
.set_num_inputs(1)
.add_argument("x", "Tensor", "The input tensor.")
- .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoSplit);
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoSplit)
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutSplit);
/* relax.squeeze */
TVM_REGISTER_NODE_TYPE(SqueezeAttrs);
@@ -857,11 +980,67 @@ StructInfo InferStructInfoSqueeze(const Call& call, const
BlockBuilder& ctx) {
}
}
+InferLayoutOutput InferLayoutSqueeze(const Call& call,
+ const Map<String, Array<String>>&
desired_layouts,
+ const VarLayoutMap& var_layout_map) {
+ ICHECK(NoDesiredLayout(call, desired_layouts));
+
+ const auto* attrs = call->attrs.as<SqueezeAttrs>();
+ ICHECK(attrs != nullptr) << "Invalid Call";
+ const auto* tensor_sinfo =
GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+ ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
+ ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now";
+ ICHECK(tensor_sinfo->shape.defined()) << "Only support static shape for now";
+ int ndim = tensor_sinfo->ndim;
+ const auto* shape = tensor_sinfo->shape.as<ShapeExprNode>();
+ ICHECK(shape != nullptr) << "Only support static shape for now";
+
+ Array<Integer> axis;
+ if (attrs->axis.defined()) {
+ axis = attrs->axis.value();
+ } else {
+ axis.reserve(ndim);
+ for (int i = 0; i < ndim; ++i) {
+ if (tir::is_one(shape->values[i])) {
+ axis.push_back(Integer(i));
+ }
+ }
+ }
+
+ std::string axis_str(ndim, '0');
+ for (const auto& iter : axis) {
+ axis_str[iter->value] = '1';
+ }
+ for (int i = 0, j = 0; i < ndim; ++i) {
+ if (axis_str[i] != '1') {
+ axis_str[i] = 'A' + j++;
+ }
+ }
+
+ LayoutDecision existing_layout = GetLayoutDecision(var_layout_map,
call->args[0]);
+ String new_axis_str = TransposeStrLike(axis_str, InitialLayout(ndim),
existing_layout->layout);
+ Array<Integer> new_axis;
+ for (size_t i = 0; i < new_axis_str.size(); ++i) {
+ if (new_axis_str.at(i) == '1') {
+ new_axis.push_back(Integer(i));
+ }
+ }
+ std::string output_layout = new_axis_str;
+ output_layout.erase(std::remove(output_layout.begin(), output_layout.end(),
'1'),
+ output_layout.end());
+
+ ObjectPtr<SqueezeAttrs> new_attrs = make_object<SqueezeAttrs>(*attrs);
+ new_attrs->axis = new_axis;
+ return InferLayoutOutput({existing_layout},
{LayoutDecision(Layout(output_layout))},
+ Attrs(new_attrs));
+}
+
TVM_REGISTER_OP("relax.squeeze")
.set_num_inputs(1)
.set_attrs_type<SqueezeAttrs>()
.add_argument("x", "Tensor", "The input tensor.")
- .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoSqueeze);
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoSqueeze)
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutSqueeze);
void CheckCollapseShape(const Call& call, const BlockBuilder& ctx,
const Array<PrimExpr>& data_shape, const
Array<PrimExpr>& target_shape) {
diff --git a/src/relax/op/tensor/statistical.cc
b/src/relax/op/tensor/statistical.cc
index 41b99fbe36..4de8e3dd63 100644
--- a/src/relax/op/tensor/statistical.cc
+++ b/src/relax/op/tensor/statistical.cc
@@ -24,6 +24,7 @@
#include "statistical.h"
+#include <string>
#include <vector>
namespace tvm {
@@ -82,6 +83,57 @@ StructInfo InferStructInfoStatistical(const Call& call,
const BlockBuilder& ctx)
return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype);
}
+InferLayoutOutput InferLayoutStatistical(const Call& call,
+ const Map<String, Array<String>>&
desired_layouts,
+ const VarLayoutMap& var_layout_map) {
+ ICHECK(NoDesiredLayout(call, desired_layouts));
+
+ const auto* attrs = call->attrs.as<StatisticalAttrs>();
+ ICHECK(attrs != nullptr) << "Invalid Call";
+ const auto* tensor_sinfo =
GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+ ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
+ ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim";
+ int ndim = tensor_sinfo->ndim;
+
+ Array<Integer> axis;
+ if (attrs->axis.defined()) {
+ axis = attrs->axis.value();
+ } else {
+ axis.reserve(ndim);
+ for (int i = 0; i < ndim; ++i) {
+ axis.push_back(Integer(i));
+ }
+ }
+
+ std::string axis_str(ndim, '0');
+ for (const auto& iter : axis) {
+ axis_str[iter->value] = '1';
+ }
+ for (int i = 0, j = 0; i < ndim; ++i) {
+ if (axis_str[i] != '1') {
+ axis_str[i] = 'A' + j++;
+ }
+ }
+
+ LayoutDecision existing_layout = GetLayoutDecision(var_layout_map,
call->args[0]);
+ String new_axis_str = TransposeStrLike(axis_str, InitialLayout(ndim),
existing_layout->layout);
+ Array<Integer> new_axis;
+ for (size_t i = 0; i < new_axis_str.size(); ++i) {
+ if (new_axis_str.at(i) == '1') {
+ new_axis.push_back(Integer(i));
+ }
+ }
+ std::string output_layout = new_axis_str;
+ output_layout.erase(std::remove(output_layout.begin(), output_layout.end(),
'1'),
+ output_layout.end());
+
+ ObjectPtr<StatisticalAttrs> new_attrs =
make_object<StatisticalAttrs>(*attrs);
+ new_attrs->axis = new_axis;
+ return InferLayoutOutput({existing_layout},
+ {attrs->keepdims ? existing_layout :
Layout(output_layout)},
+ Attrs(new_attrs));
+}
+
TVM_REGISTER_NODE_TYPE(StatisticalAttrs);
RELAX_REGISTER_STATISTICAL_OP_INTERFACE(max);
diff --git a/src/relax/op/tensor/statistical.h
b/src/relax/op/tensor/statistical.h
index 7d322d1129..0adeb82259 100644
--- a/src/relax/op/tensor/statistical.h
+++ b/src/relax/op/tensor/statistical.h
@@ -42,19 +42,20 @@ namespace relax {
* 1. be prepended with a prefix "relax.op." as the FFI identifier string for
the make function,
* 2. be prepended with a prefix "relax." as the identifier string in the
operator registry.
*/
-#define RELAX_REGISTER_STATISTICAL_OP_INTERFACE(OpName) \
- Expr OpName(Expr x, Optional<Array<Integer>> axis, bool keepdims) { \
- ObjectPtr<StatisticalAttrs> attrs = make_object<StatisticalAttrs>(); \
- attrs->axis = std::move(axis); \
- attrs->keepdims = keepdims; \
- static const Op& op = Op::Get("relax." #OpName); \
- return Call(op, {std::move(x)}, Attrs{attrs}, {}); \
- } \
- TVM_REGISTER_GLOBAL("relax.op." #OpName).set_body_typed(OpName); \
- TVM_REGISTER_OP("relax." #OpName) \
- .set_num_inputs(1) \
- .add_argument("x", "Tensor", "The input data tensor") \
- .set_attr<FInferStructInfo>("FInferStructInfo",
InferStructInfoStatistical)
+#define RELAX_REGISTER_STATISTICAL_OP_INTERFACE(OpName)
\
+ Expr OpName(Expr x, Optional<Array<Integer>> axis, bool keepdims) {
\
+ ObjectPtr<StatisticalAttrs> attrs = make_object<StatisticalAttrs>();
\
+ attrs->axis = std::move(axis);
\
+ attrs->keepdims = keepdims;
\
+ static const Op& op = Op::Get("relax." #OpName);
\
+ return Call(op, {std::move(x)}, Attrs{attrs}, {});
\
+ }
\
+ TVM_REGISTER_GLOBAL("relax.op." #OpName).set_body_typed(OpName);
\
+ TVM_REGISTER_OP("relax." #OpName)
\
+ .set_num_inputs(1)
\
+ .add_argument("x", "Tensor", "The input data tensor")
\
+ .set_attr<FInferStructInfo>("FInferStructInfo",
InferStructInfoStatistical) \
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutStatistical)
/*!
* \brief Computes the maximum value of tensor elements over given axes.
diff --git a/src/relax/op/tensor/ternary.cc b/src/relax/op/tensor/ternary.cc
index 8820c07afd..93652f43ef 100644
--- a/src/relax/op/tensor/ternary.cc
+++ b/src/relax/op/tensor/ternary.cc
@@ -90,12 +90,28 @@ StructInfo InferStructInfoEwiseFMA(const Call& call, const
BlockBuilder& ctx) {
return TensorStructInfo(output_dtype, ndim);
}
+InferLayoutOutput InferLayoutEwiseFMA(const Call& call,
+ const Map<String, Array<String>>&
desired_layouts,
+ const VarLayoutMap& var_layout_map) {
+ ICHECK(NoDesiredLayout(call, desired_layouts));
+
+ LayoutDecision layout0 = GetLayoutDecision(var_layout_map, call->args[0]);
+ LayoutDecision layout1 = GetLayoutDecision(var_layout_map, call->args[1]);
+ LayoutDecision layout2 = GetLayoutDecision(var_layout_map, call->args[2]);
+ LayoutDecision layout = layout0;
+ if (NLayoutEqual()(layout1, layout2)) {
+ layout = layout1;
+ }
+ return InferLayoutOutput({layout, layout, layout}, {layout},
Attrs(call->attrs));
+}
+
TVM_REGISTER_OP("relax.ewise_fma")
.set_num_inputs(3)
.add_argument("x1", "Tensor", "The left hand operand of the
multiplication")
.add_argument("x2", "Tensor", "The right hand operand of the
multiplication")
.add_argument("x3", "Tensor", "The operand of the addition")
- .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoEwiseFMA);
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoEwiseFMA)
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutEwiseFMA);
Expr ewise_fma(Expr x1, Expr x2, Expr x3) {
static const Op& op = Op::Get("relax.ewise_fma");
diff --git a/src/relax/transform/convert_layout.cc
b/src/relax/transform/convert_layout.cc
new file mode 100644
index 0000000000..4f36cfbc0f
--- /dev/null
+++ b/src/relax/transform/convert_layout.cc
@@ -0,0 +1,309 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/relax/transform/convert_layout.cc
+ * \brief Automatic layout conversion pass, especially for axis swapping.
+ */
+
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/nested_msg.h>
+#include <tvm/relax/op_attr_types.h>
+#include <tvm/relax/transform.h>
+
+#include "../op/tensor/manipulate.h"
+#include "infer_layout_utils.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relax {
+
+using tir::Layout;
+
+/*!
+ * \brief Main logic to convert the layout of conv2d. Other ops
+ * can adapt to such layout conversion following conv2d accordingly.
+ *
+ * Structurally speaking, a Relax function is composed of a series of
VarBinding and
+ * MatchCast. And a specific class of VarBindings is the basic unit we want to
rewrite.
+ * Formally, they are of the form:
+ *
+ * var = Call(Op, [args], attrs)
+ *
+ * where Op is a specific op we want to rewrite, and attrs is the attributes
of the op.
+ * var and args are all exprs with type Tensor or Tuple of Tensors. They might
+ * be vars, constants, or Tuple of vars and constants.
+ *
+ * We register the layout inference function for each op (FRelaxInferLayout),
which accepts the
+ * current call, the desired layout of conv2d ops, and the layout map of
previous vars. The result
+ * of the layout inference function is contained in an InferLayoutOutput
object, which contains 3
+ * fields: input_layouts, output_layouts, and attr, which represents the
expected input layout,
+ * output_layout and converted attrs of the new op call.
+ *
+ * The rewrite pass does the rewriting in a single forward pass, where for
each Call(Op),
+ * we collect the current Layout of each input var, and let the InferLayout
function to infer the
+ * desired layout of the output. The rewriter will use these info to convert
+ * the layout of inputs and attrs of the op call, and note down the new layout
of the output.
+ *
+ * The desired layout of conv2d ops is a map from the name of the op to the
desired layout of the
+ * feature map, weight and output. For example, if we want to convert
the layout of conv2d
+ * from NCHW to NHWC, we can set the desired layout of conv2d to be {"conv2d":
["NHWC", "OHWI"]}.
+ *
+ * The way we represent the layout of a var is a NLayout object, which is a
nested tuple of Layout.
+ * The incoming layout of the module will be set as the default layout (We use
ABCD... as the
+ * default) Note that for operators like conv, pool, people typically use NHWC
to refer to the axes.
+ * But to be generic and support more operators, we use ABCD... to refer to
the axes.
+ *
+ * Note that currently the layout conversion of conv2d only supports axis
swapping, such as NCHW to
+ * NHWC. Packed layout such as NCHW to NCHW4c is not supported now.
+ */
+class LayoutConvertMutator : public ExprMutator {
+ public:
+ explicit LayoutConvertMutator(const Map<String, Array<String>>&
desired_layouts)
+ : desired_layouts_(desired_layouts) {}
+
+ private:
+ Array<Integer> LayoutToIntegers(const Layout& layout) {
+ Array<Integer> ret;
+ LayoutDecision src = InitialLayoutDecision(layout.ndim());
+ for (size_t i = 0; i < layout.ndim(); ++i) {
+ ret.push_back(Integer(src->layout.IndexOf(layout[i])));
+ }
+ return ret;
+ }
+
+ Expr RewriteExpr(const Expr& expr, const NLayout& to) {
+ auto fvisitleaf = [&](const Expr& expr, std::array<NLayout, 2> layouts) ->
Expr {
+ NLayout from = layouts[0], to = layouts[1];
+ if (NLayoutEqual()(from, to)) return expr;
+ // If not both from and to are unknown, then none of them can be unknown.
+ ICHECK(!NLayoutEqual()(from, LayoutDecision::InitUnknownDim()) &&
+ !NLayoutEqual()(to, LayoutDecision::InitUnknownDim()))
+ << "Cannot convert when exactly one of the layouts is unknown";
+ const auto* tensor = GetStructInfoAs<TensorStructInfoNode>(expr);
+ ICHECK(tensor != nullptr) << "Expect a tensor, but got: " << expr;
+ Layout axes = TransposeLike(InitialLayoutDecision(tensor->ndim)->layout,
+ from.LeafValue()->layout,
to.LeafValue()->layout);
+ return permute_dims(expr, LayoutToIntegers(axes));
+ };
+ return TransformTupleLeaf<LayoutDecision>(
+ VarReplacer::Replace(expr, var_remap_),
+ std::array<NLayout, 2>({GetNLayout(var_layout_map_, expr), to}),
fvisitleaf);
+ }
+
+ Array<Expr> RewriteArgs(const Array<Expr>& args, const Array<NLayout>& to) {
+ ICHECK(args.size() == to.size());
+ std::vector<Expr> new_args;
+ for (size_t i = 0; i < args.size(); ++i) {
+ new_args.push_back(RewriteExpr(args[i], to[i]));
+ }
+ return std::move(new_args);
+ }
+
+ void VisitBinding(const Binding& binding) final {
+ // Emit the binding
+ ExprMutator::VisitBinding(binding);
+ // The layout is default to be initial if not rewritten.
+ if (var_layout_map_.find(binding->var) == var_layout_map_.end()) {
+ var_layout_map_[binding->var] = InitialNLayout(binding->var);
+ }
+ }
+
+ Expr VisitVars_(const Var& var) {
+ // When we encounter a var used outside of inferrable regions, we rewrite it to
the initial layout.
+ return RewriteExpr(var, InitialNLayout(var));
+ }
+
+ Expr VisitExpr_(const VarNode* op) final { return
VisitVars_(GetRef<Var>(op)); }
+
+ Expr VisitExpr_(const DataflowVarNode* op) final { return
VisitVars_(GetRef<Var>(op)); }
+
+ bool HasUnknownDimTensor(const NLayout& nlayout) {
+ bool find = false;
+ auto fvisit = [&](const LayoutDecision& layout) {
+ find = find | (NLayoutEqual()(layout, LayoutDecision::InitUnknownDim()));
+ };
+ ForEachLeaf<LayoutDecision>(nlayout, fvisit);
+ return find;
+ }
+
+ bool HasUnknownDimTensor(const Array<Expr>& args) {
+ for (const auto& arg : args) {
+ if (IsNestedTensor(arg)) {
+ if (HasUnknownDimTensor(GetNLayout(var_layout_map_, arg))) {
+ return true;
+ }
+ }
+ }
+ return false;
+ }
+
+ Optional<InferLayoutOutput> GetInferLayoutInfo(const CallNode* call_node,
+ const Map<String,
Array<String>>& desired_layouts,
+ const VarLayoutMap&
var_layout_map) {
+ const OpNode* op_node = call_node->op.as<OpNode>();
+ if (op_node == nullptr) return NullOpt;
+ Op op = Downcast<Op>(GetRef<Op>(op_node));
+ const auto attr_map =
Op::GetAttrMap<FRelaxInferLayout>("FRelaxInferLayout");
+ if (attr_map.count(op) && !HasUnknownDimTensor(call_node->args)) {
+ // If the op has FRelaxInferLayout, and all the input tensors have known
ndim
+ FRelaxInferLayout f = attr_map[op];
+ return f(GetRef<Call>(call_node), desired_layouts, var_layout_map);
+ } else {
+ // Otherwise, we use the default policy.
+ return NullOpt;
+ }
+ }
+
+ void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node)
final {
+ Optional<InferLayoutOutput> res =
+ GetInferLayoutInfo(call_node, desired_layouts_, var_layout_map_);
+ ObjectPtr<CallNode> new_call = make_object<CallNode>(*call_node);
+ new_call->struct_info_ = NullOpt;
+ if (!res.defined() ||
+ (!IsNestedTensor(binding->var) &&
!binding->var->IsInstance<DataflowVarNode>())) {
+ // Default policy: use the initial layout.
+ // When we don't have the infer layout info, or it's a non-tensor global
var binding.
+ std::vector<NLayout> input_layout;
+ for (const auto& arg : call_node->args) {
+ input_layout.push_back(InitialNLayout(arg));
+ }
+ Array<Expr> new_args = RewriteArgs(call_node->args,
std::move(input_layout));
+ new_call->args = std::move(new_args);
+ ReEmitBinding(binding, builder_->Normalize(Call(new_call)));
+ // update the layout map
+ var_layout_map_[binding->var] = InitialNLayout(binding->var);
+ } else {
+ // Convert the layout according to the inferred layout output.
+ Array<Expr> new_args = RewriteArgs(call_node->args,
res.value()->input_layouts);
+ new_call->args = std::move(new_args);
+ new_call->attrs = std::move(res.value()->new_attrs);
+ Expr cur_call = builder_->Normalize(Call(new_call));
+ if (binding->var->IsInstance<DataflowVarNode>()) {
+ // Dataflow var, we emit the rewritten call.
+ ReEmitBinding(binding, cur_call);
+ // update the layout map
+ var_layout_map_[binding->var] = res.value()->output_layouts[0];
+ } else {
+ // Global var (tensor), we rewrite it to initial layout
+ ICHECK(IsNestedTensor(binding->var));
+ if (!NLayoutEqual()(res.value()->output_layouts[0],
InitialNLayout(binding->var))) {
+ Var new_var = builder_->Emit(cur_call);
+ var_layout_map_[new_var] = res.value()->output_layouts[0];
+ cur_call = builder_->Normalize(RewriteExpr(new_var,
InitialNLayout(binding->var)));
+ }
+ ReEmitBinding(binding, cur_call);
+ // update the layout map
+ var_layout_map_[binding->var] = InitialNLayout(binding->var);
+ }
+ }
+ }
+
+ void VisitBinding_(const VarBindingNode* binding, const TupleNode* val)
final {
+ std::vector<NLayout> input_layout;
+ for (const auto& field : val->fields) {
+ if (binding->var->IsInstance<DataflowVarNode>()) {
+ // Df var: Use the current realized layout to group the tuple;
+ input_layout.push_back(GetNLayout(var_layout_map_, field));
+ } else {
+ // Global var: Use the initial layout to group the tuple;
+ input_layout.push_back(InitialNLayout(field));
+ }
+ }
+ Array<Expr> new_fields = RewriteArgs(val->fields, std::move(input_layout));
+ if (IsNestedTensor(binding->var)) {
+ ReEmitBinding(binding, builder_->Normalize(Tuple(new_fields)));
+ var_layout_map_[binding->var] = input_layout;
+ }
+ }
+
+ void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode*
val) final {
+ NLayout input_layout = binding->var->IsInstance<DataflowVarNode>()
+ ? GetNLayout(var_layout_map_, val->tuple)
+ : InitialNLayout(val->tuple);
+ ReEmitBinding(binding, builder_->Normalize(
+ TupleGetItem(RewriteExpr(val->tuple,
input_layout), val->index)));
+ // update the layout map
+ var_layout_map_[binding->var] = input_layout.NestedArray()[val->index];
+ }
+
+ void VisitBinding_(const MatchCastNode* binding) final {
+ if (!binding->var->IsInstance<DataflowVarNode>()) {
+ ExprMutator::VisitBinding_(binding);
+ return;
+ }
+ NLayout from_layout = InitialNLayout(binding->value);
+ NLayout input_layout = GetNLayout(var_layout_map_, binding->value);
+ auto fvisitleaf = [&](const StructInfo& sinfo, std::array<NLayout, 2>
layouts) -> StructInfo {
+ NLayout from = layouts[0], to = layouts[1];
+ if (NLayoutEqual()(from, to)) return sinfo;
+ // If not both from and to are unknown, then none of them can be unknown.
+ ICHECK(!NLayoutEqual()(from, LayoutDecision::InitUnknownDim()) &&
+ !NLayoutEqual()(to, LayoutDecision::InitUnknownDim()))
+ << "Cannot convert when exactly one of the layouts is unknown";
+ const TensorStructInfoNode* tsinfo = sinfo.as<TensorStructInfoNode>();
+ ICHECK(tsinfo != nullptr) << "We can not set layout for non-tensor
struct";
+ if (!tsinfo->shape.defined()) return sinfo;
+ const ShapeExprNode* shape = tsinfo->shape.value().as<ShapeExprNode>();
+ if (shape == nullptr) return sinfo;
+ ICHECK_EQ(shape->values.size(), to.LeafValue()->layout.ndim());
+ std::vector<PrimExpr> new_shape;
+ for (size_t i = 0; i < shape->values.size(); ++i) {
+ new_shape.push_back(
+
shape->values[from.LeafValue()->layout.IndexOf(to.LeafValue()->layout[i])]);
+ }
+ return TensorStructInfo(ShapeExpr(new_shape), tsinfo->dtype,
tsinfo->span);
+ };
+ StructInfo new_struct_info = TransformTupleLeaf<LayoutDecision>(
+ binding->struct_info, std::array<NLayout, 2>({from_layout,
input_layout}), fvisitleaf);
+ // re-emit old binding if nothing changes
+ if (new_struct_info.same_as(binding->struct_info)) {
+ builder_->EmitNormalized(GetRef<MatchCast>(binding));
+ } else {
+ Var new_var =
+ builder_->EmitMatchCast(RewriteExpr(binding->value, input_layout),
new_struct_info);
+ var_layout_map_[binding->var] = input_layout;
+ this->var_remap_[binding->var->vid] = new_var;
+ }
+ }
+
+ std::unordered_map<Var, NLayout, ObjectPtrHash, ObjectPtrEqual>
var_layout_map_;
+ Map<String, Array<String>> desired_layouts_;
+}; // class LayoutConvertMutator
+
+DataflowBlock ConvertLayoutPass(const DataflowBlock& df_block,
+ Map<String, Array<String>> desired_layouts) {
+ LayoutConvertMutator mutator(desired_layouts);
+ return Downcast<DataflowBlock>(mutator.VisitBindingBlock(df_block));
+}
+
+namespace transform {
+
+Pass ConvertLayout(Map<String, Array<String>> desired_layouts) {
+ runtime::TypedPackedFunc<DataflowBlock(DataflowBlock, IRModule,
PassContext)> pass_func =
+ [=](DataflowBlock df_block, IRModule m, PassContext pc) {
+ return Downcast<DataflowBlock>(ConvertLayoutPass(df_block,
desired_layouts));
+ };
+ return CreateDataflowBlockPass(pass_func, 0, "ConvertLayout", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.ConvertLayout").set_body_typed(ConvertLayout);
+
+} // namespace transform
+} // namespace relax
+} // namespace tvm
diff --git a/src/relax/transform/infer_layout_utils.cc
b/src/relax/transform/infer_layout_utils.cc
new file mode 100644
index 0000000000..e603fb4a1b
--- /dev/null
+++ b/src/relax/transform/infer_layout_utils.cc
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "infer_layout_utils.h"
+
+#include "utils.h"
+
+namespace tvm {
+namespace relax {
+
+using tir::IterVar;
+using tir::Layout;
+
+Layout TransposeLike(const Layout& input, const Layout& src, const Layout&
dst) {
+ ICHECK(src.ndim() == dst.ndim() && input.ndim() == src.ndim())
+ << "Layouts must have the same size";
+ std::vector<IterVar> axes;
+ for (size_t i = 0; i < src.ndim(); ++i) {
+ axes.push_back(input->axes[src.IndexOf(dst[i])]);
+ }
+ return Layout(axes);
+}
+
+String TransposeStrLike(const String& input, const Layout& src, const Layout&
dst) {
+ ICHECK(src.ndim() == dst.ndim() && input.size() == src.ndim())
+ << "Layouts must have the same size";
+ std::string axes;
+ for (size_t i = 0; i < src.ndim(); ++i) {
+ axes.push_back(input.at(src.IndexOf(dst[i])));
+ }
+ return axes;
+}
+
+int FindAxis(const Layout& dst, int axis) {
+ axis = (axis + dst.ndim()) % dst.ndim();
+ return dst.name().find('A' + axis);
+}
+
+Layout InitialLayout(int ndim) {
+ ICHECK(ndim > 0 && ndim <= 26) << "Only support up to 26 dimensions";
+ return Layout("ABCDEFGHIJKLMNOPQRSTUVWXYZ").SubLayout(0, ndim);
+}
+
+LayoutDecision InitialLayoutDecision(int ndim) {
+ if (ndim == kUnknownNDim) {
+ return LayoutDecision::InitUnknownDim();
+ }
+ ICHECK(ndim >= 0 && ndim <= 26) << "Only support up to 26 dimensions";
+ return Layout("ABCDEFGHIJKLMNOPQRSTUVWXYZ").SubLayout(0, ndim);
+}
+
+NLayout InitialNLayout(const StructInfo& sinfo) {
+ auto fmapleaf = [&](const StructInfo& sinfo) -> NLayout {
+ if (const auto* tensor_sinfo = sinfo.as<TensorStructInfoNode>()) {
+ return NLayout(InitialLayoutDecision(tensor_sinfo->ndim));
+ }
+ return LayoutDecision::InitUnknownDim();
+ };
+ return MapToNestedMsg<LayoutDecision>(sinfo, fmapleaf);
+}
+
+NLayout InitialNLayout(const Expr& expr) { return
InitialNLayout(GetStructInfo(expr)); }
+
+LayoutDecision GetLayoutDecision(const VarLayoutMap& var_layout_map, const
Expr& arg) {
+ NLayout nlayout = GetNLayout(var_layout_map, arg);
+ ICHECK(nlayout.IsLeaf()) << "Cannot get layout for " << arg;
+ return nlayout.LeafValue();
+}
+
+NLayout GetNLayout(const VarLayoutMap& var_layout_map, const Expr& arg) {
+ auto fmapleaf = [&](const Expr& expr) -> NLayout {
+ if (const auto* var = expr.as<VarNode>()) {
+ auto it = var_layout_map.find(GetRef<Var>(var));
+ if (it != var_layout_map.end()) {
+ return (*it).second;
+ } else {
+ return InitialNLayout(expr);
+ }
+ } else if (const auto* constant = expr.as<ConstantNode>()) {
+ return InitialLayoutDecision(constant->data.Shape().size());
+ }
+ return LayoutDecision::InitUnknownDim();
+ };
+ return MapToNestedMsg<LayoutDecision>(arg, fmapleaf);
+}
+
+bool NoDesiredLayout(const Call& call, const Map<String, Array<String>>&
desired_layouts) {
+ const OpNode* op_node = call->op.as<OpNode>();
+ if (op_node == nullptr) return false;
+ const auto& it = desired_layouts.find(op_node->name);
+ return it == desired_layouts.end();
+}
+
+LayoutDecision FollowDecision(const LayoutDecision& src, int dst_ndim) {
+ int src_ndim = src->layout.ndim();
+ // broadcast case
+ if (src_ndim == dst_ndim) {
+ return src;
+ } else {
+ ICHECK_LT(src_ndim, dst_ndim) << "Cannot broadcast from " << src_ndim << "
to " << dst_ndim;
+ std::string layout = InitialLayout(dst_ndim - src_ndim).name();
+ for (int i = 0; i < src_ndim; ++i) {
+ layout.push_back(src->layout.name()[i] + dst_ndim - src_ndim);
+ }
+ return LayoutDecision(Layout(layout));
+ }
+}
+
+} // namespace relax
+} // namespace tvm
diff --git a/src/relax/transform/infer_layout_utils.h
b/src/relax/transform/infer_layout_utils.h
new file mode 100644
index 0000000000..2cbbe23ede
--- /dev/null
+++ b/src/relax/transform/infer_layout_utils.h
@@ -0,0 +1,244 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file infer_layout_utils.h
+ * \brief Utility functions and data structures for layout inference, used by the
+ * ConvertLayout pass to infer layout decisions for operators and to
+ * propagate those decisions across the graph.
+ */
+
+#ifndef TVM_RELAX_TRANSFORM_INFER_LAYOUT_UTILS_H_
+#define TVM_RELAX_TRANSFORM_INFER_LAYOUT_UTILS_H_
+
+#include <tvm/relax/attrs/create.h>
+#include <tvm/relax/attrs/datatype.h>
+#include <tvm/relax/attrs/image.h>
+#include <tvm/relax/attrs/linear_algebra.h>
+#include <tvm/relax/attrs/manipulate.h>
+#include <tvm/relax/attrs/nn.h>
+#include <tvm/relax/attrs/statistical.h>
+#include <tvm/relax/expr.h>
+#include <tvm/relax/nested_msg.h>
+#include <tvm/relax/op_attr_types.h>
+#include <tvm/tir/data_layout.h>
+
+#include <array>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace relax {
+
+using tir::Layout;
+
+/*!
+ * \brief A layout decision node that holds the layout decision of the tensor.
+ * \param layout The layout of the tensor.
+ */
+class LayoutDecisionNode : public Object {
+ public:
+ /*! \brief The layout decision of the tensor. */
+ Layout layout;
+ /*! \brief Whether the number of dimensions of the tensor is unknown. */
+ bool is_unknown_dim = false;
+
+ void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("layout", &layout); }
+
+ TVM_DECLARE_BASE_OBJECT_INFO(LayoutDecisionNode, Object);
+
+ static constexpr const char* _type_key = "relax.transform.LayoutDecision";
+};
+
+class LayoutDecision : public ObjectRef {
+ public:
+ LayoutDecision(Layout layout, bool is_unknown_dim = false) { // NOLINT(*)
+ auto n = make_object<LayoutDecisionNode>();
+ n->layout = std::move(layout);
+ n->is_unknown_dim = is_unknown_dim;
+ data_ = n;
+ }
+
+ static LayoutDecision InitUnknownDim() { return
LayoutDecision(Layout::Undef(), true); }
+
+ inline std::string name() const {
+ if (operator->()->is_unknown_dim) {
+ return "unknown_dim";
+ }
+ return operator->()->layout.name();
+ }
+
+ TVM_DEFINE_OBJECT_REF_METHODS(LayoutDecision, ObjectRef, LayoutDecisionNode);
+};
+
+using NLayout = NestedMsg<LayoutDecision>;
+
+/*!
+ * \brief An output structure to hold results from FInferCorrectLayout calls.
+ * \param input_layouts Inferred input layouts.
+ * \param output_layouts Inferred output layouts.
+ * \param new_attrs Updated attributes consistent with inferred layouts.
+ */
+class InferLayoutOutputNode : public Object {
+ public:
+ Array<NLayout> input_layouts;
+ Array<NLayout> output_layouts;
+ Attrs new_attrs;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ v->Visit("input_layouts", &input_layouts);
+ v->Visit("output_layouts", &output_layouts);
+ v->Visit("new_attrs", &new_attrs);
+ }
+
+ TVM_DECLARE_BASE_OBJECT_INFO(InferLayoutOutputNode, Object);
+
+ static constexpr const char* _type_key = "relax.transform.InferLayoutOutput";
+};
+
+class InferLayoutOutput : public ObjectRef {
+ public:
+ explicit InferLayoutOutput(Array<NLayout> input_layouts, Array<NLayout>
output_layouts,
+ Attrs new_attrs) {
+ auto n = make_object<InferLayoutOutputNode>();
+ n->input_layouts = std::move(input_layouts);
+ n->output_layouts = std::move(output_layouts);
+ n->new_attrs = std::move(new_attrs);
+ data_ = n;
+ }
+ TVM_DEFINE_OBJECT_REF_METHODS(InferLayoutOutput, ObjectRef,
InferLayoutOutputNode);
+};
+
+struct NLayoutEqual {
+ bool operator()(const NLayout& a, const NLayout& b) const {
+ auto layout_equal = [](const LayoutDecision& a, const LayoutDecision& b) {
+ if (a.defined() && b.defined()) {
+ return a.name() == b.name();
+ }
+ return a.defined() == b.defined();
+ };
+ return Equal(a, b, layout_equal);
+ }
+};
+
+using VarLayoutMap = Map<Var, NLayout>;
+
+/*!
+ * \brief Layout conversion interface.
+ * \param call The call node.
+ * \param desired_layouts The desired layouts of the operator.
+ * \param var_layout_map The layout of the variables.
+ */
+using FRelaxInferLayout = runtime::TypedPackedFunc<InferLayoutOutput(
+ const Call& call, const Map<String, Array<String>>& desired_layouts,
+ const VarLayoutMap& var_layout_map)>;
+
+/*!
+ * \brief Initialize a layout given the number of dimensions.
+ * \param ndim The number of dimensions.
+ * \return The initialized layout.
+ */
+Layout InitialLayout(int ndim);
+
+/*!
+ * \brief Initialize a layout decision given the number of dimensions.
+ * \param ndim The number of dimensions.
+ * \return The initialized layout decision.
+ */
+LayoutDecision InitialLayoutDecision(int ndim);
+
+/*!
+ * \brief Initialize a nested layout decision given the struct info.
+ * \param sinfo The sinfo.
+ * \return The initialized nested layout decision.
+ */
+NLayout InitialNLayout(const StructInfo& sinfo);
+
+/*!
+ * \brief Initialize a nested layout decision given an expression.
+ * \param expr The expression.
+ * \return The initialized nested layout decision.
+ */
+NLayout InitialNLayout(const Expr& expr);
+
+/*!
+ * \brief Transpose the input layout like the src layout to the dst layout.
+ * \param input The input layout.
+ * \param src The source layout.
+ * \param dst The destination layout.
+ * \return The transposed input layout.
+ */
+Layout TransposeLike(const Layout& input, const Layout& src, const Layout&
dst);
+
+/*!
+ * \brief Transpose the input string like the src layout to the dst layout.
+ * \param input The input str.
+ * \param src The source layout.
+ * \param dst The destination layout.
+ * \return The transposed input str.
+ */
+String TransposeStrLike(const String& input, const Layout& src, const Layout&
dst);
+
+/*!
+ * \brief Find axis in the dst layout. 0 represents the first axis, 1
represents the second axis,
+ * etc.
+ * \param dst The destination layout.
+ * \param axis The axis to be found
+ * \return The axis in the dst layout.
+ */
+int FindAxis(const Layout& dst, int axis);
+
+/*!
+ * \brief Get the layout decision of the expr. The expr must be a Tensor.
+ * \param var_layout_map The layout of the variables.
+ * \param arg The expr.
+ * \return The layout decision of the expr.
+ */
+LayoutDecision GetLayoutDecision(const VarLayoutMap& var_layout_map, const
Expr& arg);
+
+/*!
+ * \brief Get the nested layout decision of the expr. The expr must be a
nested Tensor.
+ * \param var_layout_map The layout of the variables.
+ * \param arg The expr.
+ * \return The nested layout decision of the expr.
+ */
+NLayout GetNLayout(const VarLayoutMap& var_layout_map, const Expr& arg);
+
+/*!
+ * \brief Check if the op is not in the desired layout
+ * \param call The call node contains the op
+ * \param desired_layouts The desired layouts of the operator.
+ * \return True if the op is not in the desired layout.
+ */
+bool NoDesiredLayout(const Call& call, const Map<String, Array<String>>&
desired_layouts);
+
+/*!
+ * \brief Make a tensor with dst_ndim dimensions follow the src layout decision.
+ * \param src The source layout decision.
+ * \param dst_ndim The number of dimensions of the tensor.
+ * \return The layout decision of the tensor.
+ */
+LayoutDecision FollowDecision(const LayoutDecision& src, int dst_ndim);
+
+} // namespace relax
+} // namespace tvm
+
+#endif // TVM_RELAX_TRANSFORM_INFER_LAYOUT_UTILS_H_
diff --git a/src/relax/transform/utils.cc b/src/relax/transform/utils.cc
new file mode 100644
index 0000000000..9a19115f62
--- /dev/null
+++ b/src/relax/transform/utils.cc
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "utils.h"
+
+namespace tvm {
+namespace relax {
+
+bool IsNestedTensor(const StructInfo& sinfo) {
+ return IsNestedTensorConditioned(sinfo, [](const TensorStructInfo& sinfo) {
return true; });
+}
+
+bool IsNestedTensor(const Expr& expr) { return
IsNestedTensor(GetStructInfo(expr)); }
+
+} // namespace relax
+} // namespace tvm
diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h
index 463e69d56c..003519cffc 100644
--- a/src/relax/transform/utils.h
+++ b/src/relax/transform/utils.h
@@ -155,6 +155,48 @@ bool IsNestedTensorConditioned(const StructInfo& sinfo,
FType f_condition) {
return false;
}
+/*!
+ * \brief Check if the given StructInfo is a nested tensor.
+ * \param sinfo The StructInfo to be checked.
+ * \return true if the given StructInfo is a nested tensor.
+ */
+bool IsNestedTensor(const StructInfo& sinfo);
+
+/*!
+ * \brief Check if the given expr is a nested tensor.
+ * \param expr The expr to be checked.
+ * \return true if the given expr is a nested tensor.
+ */
+bool IsNestedTensor(const Expr& expr);
+
+// TODO(@bohan): implement a postorder traversal function that accepts a visitor closure
+class VarReplacer : public ExprMutator {
+ public:
+ using VarMap = std::unordered_map<Id, Var, ObjectPtrHash, ObjectPtrEqual>;
+
+ explicit VarReplacer(const VarMap& var_remap) : var_remap_(var_remap) {}
+
+ static Expr Replace(const Expr& expr, const VarMap& var_remap) {
+ VarReplacer replacer(var_remap);
+ return replacer(expr);
+ }
+
+ private:
+ Expr VisitExpr_(const VarNode* op) final {
+ Var var = GetRef<Var>(op);
+ auto it = var_remap_.find(var->vid);
+ return it == var_remap_.end() ? var : it->second;
+ }
+
+ Expr VisitExpr_(const DataflowVarNode* op) final {
+ Var var = GetRef<Var>(op);
+ auto it = var_remap_.find(var->vid);
+ return it == var_remap_.end() ? var : it->second;
+ }
+
+ const VarMap& var_remap_;
+};
+
/*!
* \brief Create a Constant with a scalar
*
diff --git a/tests/python/relax/test_transform_convert_layout.py
b/tests/python/relax/test_transform_convert_layout.py
new file mode 100644
index 0000000000..1b65c0812a
--- /dev/null
+++ b/tests/python/relax/test_transform_convert_layout.py
@@ -0,0 +1,1352 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+from tvm.relax.transform import ConvertLayout, Normalize
+from tvm.script.parser import ir as I, relax as R, tir as T
+
+
+def verify(input, expected):
+ mod = ConvertLayout({"relax.nn.conv2d": ["NHWC", "OHWI"]})(input)
+ mod = Normalize()(mod)
+ print(mod.script())
+ tvm.ir.assert_structural_equal(mod, expected)
+
+
+def test_conv2d():
+ @I.ir_module
+ class Input:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3),
"float32")
+ ) -> R.Tensor(None, "float32", ndim=4):
+ with R.dataflow():
+ gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w,
out_dtype="float32")
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3,
3, 3), dtype="float32")
+ ) -> R.Tensor(None, dtype="float32", ndim=4):
+ with R.dataflow():
+ lv: R.Tensor((2, 28, 28, 3), dtype="float32") =
R.permute_dims(x, axes=[0, 2, 3, 1])
+ lv1: R.Tensor((4, 3, 3, 3), dtype="float32") =
R.permute_dims(w, axes=[0, 2, 3, 1])
+ lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+ lv,
+ lv1,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NHWC",
+ kernel_layout="OHWI",
+ out_layout="NHWC",
+ out_dtype="float32",
+ )
+ gv: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims(
+ lv2, axes=[0, 3, 1, 2]
+ )
+ R.output(gv)
+ return gv
+
+ verify(Input, Expected)
+
+
+def test_conv2d_onlydim():
+ @I.ir_module
+ class Input:
+ @R.function
+ def main(
+ x: R.Tensor("float32", ndim=4), w: R.Tensor("float32", ndim=4)
+ ) -> R.Tensor(None, "float32", ndim=4):
+ with R.dataflow():
+ gv: R.Tensor("float32", ndim=4) = R.nn.conv2d(x, w,
out_dtype="float32")
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor(dtype="float32", ndim=4), w: R.Tensor(dtype="float32",
ndim=4)
+ ) -> R.Tensor(dtype="float32", ndim=4):
+ with R.dataflow():
+ lv: R.Tensor(dtype="float32", ndim=4) = R.permute_dims(x,
axes=[0, 2, 3, 1])
+ lv1: R.Tensor(dtype="float32", ndim=4) = R.permute_dims(w,
axes=[0, 2, 3, 1])
+ lv2: R.Tensor(dtype="float32", ndim=4) = R.nn.conv2d(
+ lv,
+ lv1,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NHWC",
+ kernel_layout="OHWI",
+ out_layout="NHWC",
+ out_dtype="float32",
+ )
+ gv: R.Tensor(dtype="float32", ndim=4) = R.permute_dims(lv2,
axes=[0, 3, 1, 2])
+ R.output(gv)
+ return gv
+
+ verify(Input, Expected)
+
+
+def test_conv2d_symbolic():
+ @I.ir_module
+ class Input:
+ @R.function
+ def main(
+ x: R.Tensor("float32", ndim=4), w: R.Tensor("float32", ndim=4)
+ ) -> R.Tensor(None, "float32", ndim=4):
+ with R.dataflow():
+ N, C, H, W = T.int64(), T.int64(), T.int64(), T.int64()
+ lv0 = R.match_cast(x, R.Tensor((N, C, H, W), "float32"))
+ gv: R.Tensor("float32", ndim=4) = R.nn.conv2d(lv0, w,
out_dtype="float32")
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor(dtype="float32", ndim=4), w: R.Tensor(dtype="float32",
ndim=4)
+ ) -> R.Tensor(dtype="float32", ndim=4):
+ N = T.int64()
+ C = T.int64()
+ H = T.int64()
+ W = T.int64()
+ with R.dataflow():
+ lv0: R.Tensor((N, C, H, W), dtype="float32") = R.match_cast(
+ x, R.Tensor((N, C, H, W), dtype="float32")
+ )
+ lv: R.Tensor((N, H, W, C), dtype="float32") =
R.permute_dims(lv0, axes=[0, 2, 3, 1])
+ lv1: R.Tensor(dtype="float32", ndim=4) = R.permute_dims(w,
axes=[0, 2, 3, 1])
+ lv2: R.Tensor(dtype="float32", ndim=4) = R.nn.conv2d(
+ lv,
+ lv1,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NHWC",
+ kernel_layout="OHWI",
+ out_layout="NHWC",
+ out_dtype="float32",
+ )
+ gv: R.Tensor(dtype="float32", ndim=4) = R.permute_dims(lv2,
axes=[0, 3, 1, 2])
+ R.output(gv)
+ return gv
+
+ verify(Input, Expected)
+
+
+def test_conv2d_matchcast_bias():
+ @I.ir_module
+ class Input:
+ @R.function
+ def main(
+ x: R.Tensor("float32", ndim=4), w: R.Tensor("float32", ndim=4)
+ ) -> R.Tensor(None, "float32", ndim=4):
+ with R.dataflow():
+ lv0: R.Tensor("float32", ndim=4) = R.nn.conv2d(x, w,
out_dtype="float32")
+ N, C, H, W = T.int64(), T.int64(), T.int64(), T.int64()
+ lv1 = R.match_cast(lv0, R.Tensor((N, C, H, W), "float32"))
+ gv = R.add(lv1, w)
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor(dtype="float32", ndim=4), w: R.Tensor(dtype="float32",
ndim=4)
+ ) -> R.Tensor(dtype="float32", ndim=4):
+ N = T.int64()
+ H = T.int64()
+ W = T.int64()
+ C = T.int64()
+ with R.dataflow():
+ lv: R.Tensor(dtype="float32", ndim=4) = R.permute_dims(x,
axes=[0, 2, 3, 1])
+ lv1: R.Tensor(dtype="float32", ndim=4) = R.permute_dims(w,
axes=[0, 2, 3, 1])
+ lv0: R.Tensor(dtype="float32", ndim=4) = R.nn.conv2d(
+ lv,
+ lv1,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NHWC",
+ kernel_layout="OHWI",
+ out_layout="NHWC",
+ out_dtype="float32",
+ )
+ lv2: R.Tensor((N, H, W, C), dtype="float32") = R.match_cast(
+ lv0, R.Tensor((N, H, W, C), dtype="float32")
+ )
+ lv3: R.Tensor(dtype="float32", ndim=4) = R.permute_dims(w,
axes=[0, 2, 3, 1])
+ lv4: R.Tensor(dtype="float32", ndim=4) = R.add(lv2, lv3)
+ gv: R.Tensor(dtype="float32", ndim=4) = R.permute_dims(lv4,
axes=[0, 3, 1, 2])
+ R.output(gv)
+ return gv
+
+ verify(Input, Expected)
+
+
+def test_conv2d_relu():
+ @I.ir_module
+ class Input:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3),
"float32")
+ ) -> R.Tensor(None, "float32", ndim=4):
+ with R.dataflow():
+ gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w,
out_dtype="float32")
+ gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv)
+ R.output(gv2)
+ return gv2
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3,
3, 3), dtype="float32")
+ ) -> R.Tensor(None, dtype="float32", ndim=4):
+ with R.dataflow():
+ lv: R.Tensor((2, 28, 28, 3), dtype="float32") =
R.permute_dims(x, axes=[0, 2, 3, 1])
+ lv1: R.Tensor((4, 3, 3, 3), dtype="float32") =
R.permute_dims(w, axes=[0, 2, 3, 1])
+ gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+ lv,
+ lv1,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NHWC",
+ kernel_layout="OHWI",
+ out_layout="NHWC",
+ out_dtype="float32",
+ )
+ lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.relu(gv)
+ gv2: R.Tensor((2, 4, 26, 26), dtype="float32") =
R.permute_dims(
+ lv2, axes=[0, 3, 1, 2]
+ )
+ R.output(gv2)
+ return gv2
+
+ verify(Input, Expected)
+
+
+def test_relu_conv2d_relu():
+ @I.ir_module
+ class Input:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3),
"float32")
+ ) -> R.Tensor(None, "float32", ndim=4):
+ with R.dataflow():
+ x0: R.Tensor((2, 3, 28, 28), "float32") = R.nn.relu(x)
+ gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x0, w,
out_dtype="float32")
+ gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv)
+ R.output(gv2)
+ return gv2
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3,
3, 3), dtype="float32")
+ ) -> R.Tensor(None, dtype="float32", ndim=4):
+ with R.dataflow():
+ x0: R.Tensor((2, 3, 28, 28), dtype="float32") = R.nn.relu(x)
+ lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(
+ x0, axes=[0, 2, 3, 1]
+ )
+ lv1: R.Tensor((4, 3, 3, 3), dtype="float32") =
R.permute_dims(w, axes=[0, 2, 3, 1])
+ gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+ lv,
+ lv1,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NHWC",
+ kernel_layout="OHWI",
+ out_layout="NHWC",
+ out_dtype="float32",
+ )
+ lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.relu(gv)
+ gv2: R.Tensor((2, 4, 26, 26), dtype="float32") =
R.permute_dims(
+ lv2, axes=[0, 3, 1, 2]
+ )
+ R.output(gv2)
+ return gv2
+
+ verify(Input, Expected)
+
+
+def test_conv2d_relu_tanh():
+ @I.ir_module
+ class Input:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3),
"float32")
+ ) -> R.Tensor(None, "float32", ndim=4):
+ with R.dataflow():
+ gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w,
out_dtype="float32")
+ gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv)
+ gv3: R.Tensor((2, 4, 26, 26), "float32") = R.tanh(gv2)
+ R.output(gv3)
+ return gv3
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3,
3, 3), dtype="float32")
+ ) -> R.Tensor(None, dtype="float32", ndim=4):
+ with R.dataflow():
+ lv: R.Tensor((2, 28, 28, 3), dtype="float32") =
R.permute_dims(x, axes=[0, 2, 3, 1])
+ lv1: R.Tensor((4, 3, 3, 3), dtype="float32") =
R.permute_dims(w, axes=[0, 2, 3, 1])
+ gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+ lv,
+ lv1,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NHWC",
+ kernel_layout="OHWI",
+ out_layout="NHWC",
+ out_dtype="float32",
+ )
+ gv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.relu(gv)
+ lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.tanh(gv2)
+ gv3: R.Tensor((2, 4, 26, 26), dtype="float32") =
R.permute_dims(
+ lv2, axes=[0, 3, 1, 2]
+ )
+ R.output(gv3)
+ return gv3
+
+ verify(Input, Expected)
+
+
+def test_conv2d_add():
+ @I.ir_module
+ class Input:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 28, 28), "float32"),
+ w: R.Tensor((4, 3, 3, 3), "float32"),
+ bias: R.Tensor((2, 4, 26, 26), "float32"),
+ ) -> R.Tensor(None, "float32", ndim=4):
+ with R.dataflow():
+ gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w,
out_dtype="float32")
+ gv2: R.Tensor((2, 4, 26, 26), "float32") = R.add(gv, bias)
+ R.output(gv2)
+ return gv2
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 28, 28), dtype="float32"),
+ w: R.Tensor((4, 3, 3, 3), dtype="float32"),
+ bias: R.Tensor((2, 4, 26, 26), dtype="float32"),
+ ) -> R.Tensor(None, dtype="float32", ndim=4):
+ with R.dataflow():
+ lv: R.Tensor((2, 28, 28, 3), dtype="float32") =
R.permute_dims(x, axes=[0, 2, 3, 1])
+ lv1: R.Tensor((4, 3, 3, 3), dtype="float32") =
R.permute_dims(w, axes=[0, 2, 3, 1])
+ gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+ lv,
+ lv1,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NHWC",
+ kernel_layout="OHWI",
+ out_layout="NHWC",
+ out_dtype="float32",
+ )
+ lv2: R.Tensor((2, 26, 26, 4), dtype="float32") =
R.permute_dims(
+ bias, axes=[0, 2, 3, 1]
+ )
+ lv3: R.Tensor((2, 26, 26, 4), dtype="float32") = R.add(gv, lv2)
+ gv2: R.Tensor((2, 4, 26, 26), dtype="float32") =
R.permute_dims(
+ lv3, axes=[0, 3, 1, 2]
+ )
+ R.output(gv2)
+ return gv2
+
+ verify(Input, Expected)
+
+
+def test_conv2d_add_relu_conv2d():
+ @I.ir_module
+ class Input:
+ @R.function
+ def main(
+ x: R.Tensor((2, 4, 28, 28), "float32"),
+ w: R.Tensor((4, 4, 3, 3), "float32"),
+ bias: R.Tensor((2, 4, 26, 26), "float32"),
+ ) -> R.Tensor(None, "float32", ndim=4):
+ with R.dataflow():
+ gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w,
out_dtype="float32")
+ gv2: R.Tensor((2, 4, 26, 26), "float32") = R.add(gv, bias)
+ gv3: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv2)
+ gv4: R.Tensor((2, 4, 24, 24), "float32") = R.nn.conv2d(gv3, w,
out_dtype="float32")
+ R.output(gv4)
+ return gv4
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((2, 4, 28, 28), dtype="float32"),
+ w: R.Tensor((4, 4, 3, 3), dtype="float32"),
+ bias: R.Tensor((2, 4, 26, 26), dtype="float32"),
+ ) -> R.Tensor((2, 4, 24, 24), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((2, 28, 28, 4), dtype="float32") =
R.permute_dims(x, axes=[0, 2, 3, 1])
+ lv1: R.Tensor((4, 3, 3, 4), dtype="float32") =
R.permute_dims(w, axes=[0, 2, 3, 1])
+ gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+ lv,
+ lv1,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NHWC",
+ kernel_layout="OHWI",
+ out_layout="NHWC",
+ out_dtype="float32",
+ )
+ lv2: R.Tensor((2, 26, 26, 4), dtype="float32") =
R.permute_dims(
+ bias, axes=[0, 2, 3, 1]
+ )
+ gv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.add(gv, lv2)
+ gv3: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.relu(gv2)
+ lv3: R.Tensor((4, 3, 3, 4), dtype="float32") =
R.permute_dims(w, axes=[0, 2, 3, 1])
+ lv4: R.Tensor((2, 24, 24, 4), dtype="float32") = R.nn.conv2d(
+ gv3,
+ lv3,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NHWC",
+ kernel_layout="OHWI",
+ out_layout="NHWC",
+ out_dtype="float32",
+ )
+ gv4: R.Tensor((2, 4, 24, 24), dtype="float32") =
R.permute_dims(
+ lv4, axes=[0, 3, 1, 2]
+ )
+ R.output(gv4)
+ return gv4
+
+ verify(Input, Expected)
+
+
+def test_conv2d_fma_relu_conv2d():
+ @I.ir_module
+ class Input:
+ @R.function
+ def main(
+ x: R.Tensor((2, 4, 28, 28), "float32"),
+ w: R.Tensor((4, 4, 3, 3), "float32"),
+ scale: R.Tensor((2, 4, 26, 26), dtype="float32"),
+ bias: R.Tensor((2, 4, 26, 26), "float32"),
+ ) -> R.Tensor(None, "float32", ndim=4):
+ with R.dataflow():
+ gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w,
out_dtype="float32")
+ gv2: R.Tensor((2, 4, 26, 26), "float32") = R.ewise_fma(gv,
scale, bias)
+ gv3: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv2)
+ gv4: R.Tensor((2, 4, 24, 24), "float32") = R.nn.conv2d(gv3, w,
out_dtype="float32")
+ R.output(gv4)
+ return gv4
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((2, 4, 28, 28), dtype="float32"),
+ w: R.Tensor((4, 4, 3, 3), dtype="float32"),
+ scale: R.Tensor((2, 4, 26, 26), dtype="float32"),
+ bias: R.Tensor((2, 4, 26, 26), dtype="float32"),
+ ) -> R.Tensor((2, 4, 24, 24), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((2, 28, 28, 4), dtype="float32") =
R.permute_dims(x, axes=[0, 2, 3, 1])
+ lv1: R.Tensor((4, 3, 3, 4), dtype="float32") =
R.permute_dims(w, axes=[0, 2, 3, 1])
+ gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+ lv,
+ lv1,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NHWC",
+ kernel_layout="OHWI",
+ out_layout="NHWC",
+ out_dtype="float32",
+ )
+ lv2: R.Tensor((2, 4, 26, 26), dtype="float32") =
R.permute_dims(
+ gv, axes=[0, 3, 1, 2]
+ )
+ gv2: R.Tensor((2, 4, 26, 26), dtype="float32") =
R.ewise_fma(lv2, scale, bias)
+ gv3: R.Tensor((2, 4, 26, 26), dtype="float32") = R.nn.relu(gv2)
+ lv3: R.Tensor((2, 26, 26, 4), dtype="float32") =
R.permute_dims(
+ gv3, axes=[0, 2, 3, 1]
+ )
+ lv4: R.Tensor((4, 3, 3, 4), dtype="float32") =
R.permute_dims(w, axes=[0, 2, 3, 1])
+ lv5: R.Tensor((2, 24, 24, 4), dtype="float32") = R.nn.conv2d(
+ lv3,
+ lv4,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NHWC",
+ kernel_layout="OHWI",
+ out_layout="NHWC",
+ out_dtype="float32",
+ )
+ gv4: R.Tensor((2, 4, 24, 24), dtype="float32") =
R.permute_dims(
+ lv5, axes=[0, 3, 1, 2]
+ )
+ R.output(gv4)
+ return gv4
+
+ verify(Input, Expected)
+
+
+def test_conv2d_sum():
+ @I.ir_module
+ class Input:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3),
"float32")
+ ) -> R.Tensor(None, "float32", ndim=2):
+ with R.dataflow():
+ gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w,
out_dtype="float32")
+ gv2: R.Tensor((2, 4), "float32") = R.sum(gv, axis=[2, 3])
+ R.output(gv2)
+ return gv2
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3,
3, 3), dtype="float32")
+ ) -> R.Tensor(None, dtype="float32", ndim=2):
+ with R.dataflow():
+ lv: R.Tensor((2, 28, 28, 3), dtype="float32") =
R.permute_dims(x, axes=[0, 2, 3, 1])
+ lv1: R.Tensor((4, 3, 3, 3), dtype="float32") =
R.permute_dims(w, axes=[0, 2, 3, 1])
+ gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+ lv,
+ lv1,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NHWC",
+ kernel_layout="OHWI",
+ out_layout="NHWC",
+ out_dtype="float32",
+ )
+ gv2: R.Tensor((2, 4), dtype="float32") = R.sum(gv, axis=[1,
2], keepdims=False)
+ R.output(gv2)
+ return gv2
+
+ verify(Input, Expected)
+
+
+def test_conv2d_sum_keepdim():
+ @I.ir_module
+ class Input:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3),
"float32")
+ ) -> R.Tensor(None, "float32", ndim=4):
+ with R.dataflow():
+ gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w,
out_dtype="float32")
+ gv2: R.Tensor((2, 4, 1, 1), "float32") = R.sum(gv, axis=[2,
3], keepdims=True)
+ R.output(gv2)
+ return gv2
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3,
3, 3), dtype="float32")
+ ) -> R.Tensor(None, dtype="float32", ndim=4):
+ with R.dataflow():
+ lv: R.Tensor((2, 28, 28, 3), dtype="float32") =
R.permute_dims(x, axes=[0, 2, 3, 1])
+ lv1: R.Tensor((4, 3, 3, 3), dtype="float32") =
R.permute_dims(w, axes=[0, 2, 3, 1])
+ gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+ lv,
+ lv1,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NHWC",
+ kernel_layout="OHWI",
+ out_layout="NHWC",
+ out_dtype="float32",
+ )
+ lv2: R.Tensor((2, 1, 1, 4), dtype="float32") = R.sum(gv,
axis=[1, 2], keepdims=True)
+ gv2: R.Tensor((2, 4, 1, 1), dtype="float32") = R.permute_dims(
+ lv2, axes=[0, 3, 1, 2]
+ )
+ R.output(gv2)
+ return gv2
+
+ verify(Input, Expected)
+
+
+def test_conv2d_transpose():
+ @I.ir_module
+ class Input:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3),
"float32")
+ ) -> R.Tensor(None, "float32", ndim=4):
+ with R.dataflow():
+ gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w,
out_dtype="float32")
+ gv2: R.Tensor((26, 26, 4, 2), "float32") = R.permute_dims(gv,
axes=[3, 2, 1, 0])
+ R.output(gv2)
+ return gv2
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3,
3, 3), dtype="float32")
+ ) -> R.Tensor(None, dtype="float32", ndim=4):
+ with R.dataflow():
+ lv: R.Tensor((2, 28, 28, 3), dtype="float32") =
R.permute_dims(x, axes=[0, 2, 3, 1])
+ lv1: R.Tensor((4, 3, 3, 3), dtype="float32") =
R.permute_dims(w, axes=[0, 2, 3, 1])
+ gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+ lv,
+ lv1,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NHWC",
+ kernel_layout="OHWI",
+ out_layout="NHWC",
+ out_dtype="float32",
+ )
+ gv2: R.Tensor((26, 26, 4, 2), dtype="float32") =
R.permute_dims(
+ gv, axes=[2, 1, 3, 0]
+ )
+ R.output(gv2)
+ return gv2
+
+ verify(Input, Expected)
+
+
def test_conv2d_expand_dims():
    """ConvertLayout: conv2d (NCHW->NHWC) followed by expand_dims.

    The Expected module shows the conv2d rewritten to NHWC/OHWI with
    permute_dims inserted on both inputs, the expand_dims axes remapped to the
    new layout, and a final permute_dims restoring the original 6-D order.
    """

    @I.ir_module
    class Input:
        @R.function
        def main(
            x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32")
        ) -> R.Tensor(None, "float32", ndim=6):
            with R.dataflow():
                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
                gv2: R.Tensor((2, 1, 4, 1, 26, 26), "float32") = R.expand_dims(gv, axis=(-3, 1))
                R.output(gv2)
            return gv2

    @I.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32")
        ) -> R.Tensor(None, dtype="float32", ndim=6):
            with R.dataflow():
                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
                gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
                    lv,
                    lv1,
                    strides=[1, 1],
                    padding=[0, 0, 0, 0],
                    dilation=[1, 1],
                    groups=1,
                    data_layout="NHWC",
                    kernel_layout="OHWI",
                    out_layout="NHWC",
                    out_dtype="float32",
                )
                lv2: R.Tensor((2, 1, 26, 1, 26, 4), dtype="float32") = R.expand_dims(
                    gv, axis=[-3, 1]
                )
                gv2: R.Tensor((2, 1, 4, 1, 26, 26), dtype="float32") = R.permute_dims(
                    lv2, axes=[0, 1, 5, 3, 2, 4]
                )
                R.output(gv2)
            return gv2

    verify(Input, Expected)
+
+
def test_conv2d_expand_dims_squeeze():
    """ConvertLayout: conv2d followed by expand_dims then squeeze.

    The squeeze undoes the expand_dims, so Expected keeps the whole chain in
    NHWC and only restores NCHW with a single permute_dims at the very end.
    """

    @I.ir_module
    class Input:
        @R.function
        def main(
            x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32")
        ) -> R.Tensor(None, "float32", ndim=4):
            with R.dataflow():
                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
                gv2: R.Tensor((2, 1, 4, 1, 26, 26), "float32") = R.expand_dims(gv, axis=(-3, 1))
                gv3: R.Tensor((2, 4, 26, 26), "float32") = R.squeeze(gv2, axis=[1, 3])
                R.output(gv3)
            return gv3

    @I.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32")
        ) -> R.Tensor(None, dtype="float32", ndim=4):
            with R.dataflow():
                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
                gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
                    lv,
                    lv1,
                    strides=[1, 1],
                    padding=[0, 0, 0, 0],
                    dilation=[1, 1],
                    groups=1,
                    data_layout="NHWC",
                    kernel_layout="OHWI",
                    out_layout="NHWC",
                    out_dtype="float32",
                )
                gv2: R.Tensor((2, 1, 26, 1, 26, 4), dtype="float32") = R.expand_dims(
                    gv, axis=[-3, 1]
                )
                lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.squeeze(gv2, axis=[1, 3])
                gv3: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims(
                    lv2, axes=[0, 3, 1, 2]
                )
                R.output(gv3)
            return gv3

    verify(Input, Expected)
+
+
def test_conv2d_strided_slice():
    """ConvertLayout: conv2d followed by strided_slice.

    The slice axes [1, 2, 3] (C, H, W in NCHW) are remapped to [3, 1, 2] in
    NHWC; begin/end/strides stay the same, and a trailing permute_dims restores
    the NCHW output.
    """

    @I.ir_module
    class Input:
        @R.function
        def main(
            x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32")
        ) -> R.Tensor(None, "float32", ndim=4):
            with R.dataflow():
                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
                gv2: R.Tensor((2, 2, 9, 7), dtype="float32") = R.strided_slice(
                    gv, begin=[0, 0, 0], end=[4, 26, 26], strides=[2, 3, 4], axes=[1, 2, 3]
                )
                R.output(gv2)
            return gv2

    @I.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32")
        ) -> R.Tensor(None, dtype="float32", ndim=4):
            with R.dataflow():
                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
                gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
                    lv,
                    lv1,
                    strides=[1, 1],
                    padding=[0, 0, 0, 0],
                    dilation=[1, 1],
                    groups=1,
                    data_layout="NHWC",
                    kernel_layout="OHWI",
                    out_layout="NHWC",
                    out_dtype="float32",
                )
                lv2: R.Tensor((2, 9, 7, 2), dtype="float32") = R.strided_slice(
                    gv, axes=[3, 1, 2], begin=[0, 0, 0], end=[4, 26, 26], strides=[2, 3, 4]
                )
                gv2: R.Tensor((2, 2, 9, 7), dtype="float32") = R.permute_dims(
                    lv2, axes=[0, 3, 1, 2]
                )
                R.output(gv2)
            return gv2

    verify(Input, Expected)
+
+
def test_conv2d_relu_concat():
    """ConvertLayout: conv2d -> relu, then concat of both values on the channel axis.

    In Expected the relu and the concat both stay in NHWC (concat axis 1
    becomes axis 3), with a single permute_dims back to NCHW at the end.
    """

    @I.ir_module
    class Input:
        @R.function
        def main(
            x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32")
        ) -> R.Tensor(None, "float32", ndim=4):
            with R.dataflow():
                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
                gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv)
                gv3: R.Tensor((2, 8, 26, 26), "float32") = R.concat((gv, gv2), axis=1)
                R.output(gv3)
            return gv3

    @I.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32")
        ) -> R.Tensor(None, dtype="float32", ndim=4):
            with R.dataflow():
                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
                gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
                    lv,
                    lv1,
                    strides=[1, 1],
                    padding=[0, 0, 0, 0],
                    dilation=[1, 1],
                    groups=1,
                    data_layout="NHWC",
                    kernel_layout="OHWI",
                    out_layout="NHWC",
                    out_dtype="float32",
                )
                gv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.relu(gv)
                lv2: R.Tensor((2, 26, 26, 8), dtype="float32") = R.concat((gv, gv2), axis=3)
                gv3: R.Tensor((2, 8, 26, 26), dtype="float32") = R.permute_dims(
                    lv2, axes=[0, 3, 1, 2]
                )
                R.output(gv3)
            return gv3

    verify(Input, Expected)
+
+
def test_conv2d_relu_concat_split():
    """ConvertLayout: conv2d -> relu -> concat -> split producing a tuple.

    The split stays in NHWC (axis 1 -> 3); each tuple element is then
    individually permuted back to NCHW and re-packed into a new tuple.
    """

    @I.ir_module
    class Input:
        @R.function
        def main(x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32")):
            with R.dataflow():
                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
                gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv)
                gv3: R.Tensor((2, 8, 26, 26), "float32") = R.concat((gv, gv2), axis=1)
                gv4 = R.split(gv3, indices_or_sections=2, axis=1)
                R.output(gv4)
            return gv4

    @I.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32")
        ):
            with R.dataflow():
                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
                gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
                    lv,
                    lv1,
                    strides=[1, 1],
                    padding=[0, 0, 0, 0],
                    dilation=[1, 1],
                    groups=1,
                    data_layout="NHWC",
                    kernel_layout="OHWI",
                    out_layout="NHWC",
                    out_dtype="float32",
                )
                gv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.relu(gv)
                gv3: R.Tensor((2, 26, 26, 8), dtype="float32") = R.concat((gv, gv2), axis=3)
                gv4: R.Tuple(
                    R.Tensor((2, 26, 26, 4), dtype="float32"),
                    R.Tensor((2, 26, 26, 4), dtype="float32"),
                ) = R.split(gv3, indices_or_sections=2, axis=3)
                lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = gv4[0]
                lv3: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims(
                    lv2, axes=[0, 3, 1, 2]
                )
                lv4: R.Tensor((2, 26, 26, 4), dtype="float32") = gv4[1]
                lv5: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims(
                    lv4, axes=[0, 3, 1, 2]
                )
                gv5 = (lv3, lv5)
                R.output(gv5)
            return gv5

    verify(Input, Expected)
+
+
def test_conv2d_maxpool2d():
    """ConvertLayout: conv2d followed by max_pool2d.

    The pooling op's layout/out_layout are rewritten from NCHW to NHWC so it
    consumes the NHWC conv output directly; only one permute_dims back to NCHW
    remains at the end.
    """

    @I.ir_module
    class Input:
        @R.function
        def main(
            x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32")
        ) -> R.Tensor(None, "float32", ndim=4):
            with R.dataflow():
                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
                gv2 = R.nn.max_pool2d(
                    gv,
                    pool_size=[2, 2],
                    strides=[2, 2],
                    padding=[0, 0],
                    layout="NCHW",
                    out_layout="NCHW",
                )
                R.output(gv2)
            return gv2

    @I.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32")
        ) -> R.Tensor(None, dtype="float32", ndim=4):
            with R.dataflow():
                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
                gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
                    lv,
                    lv1,
                    strides=[1, 1],
                    padding=[0, 0, 0, 0],
                    dilation=[1, 1],
                    groups=1,
                    data_layout="NHWC",
                    kernel_layout="OHWI",
                    out_layout="NHWC",
                    out_dtype="float32",
                )
                lv2: R.Tensor((2, 13, 13, 4), dtype="float32") = R.nn.max_pool2d(
                    gv,
                    pool_size=[2, 2],
                    strides=[2, 2],
                    dilation=[1, 1],
                    padding=[0, 0, 0, 0],
                    ceil_mode=False,
                    layout="NHWC",
                    out_layout="NHWC",
                )
                gv2: R.Tensor((2, 4, 13, 13), dtype="float32") = R.permute_dims(
                    lv2, axes=[0, 3, 1, 2]
                )
                R.output(gv2)
            return gv2

    verify(Input, Expected)
+
+
def test_conv2d_avgpool2d():
    """ConvertLayout: conv2d followed by adaptive_avg_pool2d.

    The adaptive pooling's layout is rewritten NCHW -> NHWC (with an explicit
    out_layout in Expected) and the NCHW result is restored by permute_dims.
    """

    @I.ir_module
    class Input:
        @R.function
        def main(
            x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32")
        ) -> R.Tensor(None, "float32", ndim=4):
            with R.dataflow():
                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
                gv2 = R.nn.adaptive_avg_pool2d(gv, output_size=[13, 13], layout="NCHW")
                R.output(gv2)
            return gv2

    @I.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32")
        ) -> R.Tensor(None, dtype="float32", ndim=4):
            with R.dataflow():
                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
                gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
                    lv,
                    lv1,
                    strides=[1, 1],
                    padding=[0, 0, 0, 0],
                    dilation=[1, 1],
                    groups=1,
                    data_layout="NHWC",
                    kernel_layout="OHWI",
                    out_layout="NHWC",
                    out_dtype="float32",
                )
                lv2: R.Tensor((2, 13, 13, 4), dtype="float32") = R.nn.adaptive_avg_pool2d(
                    gv, output_size=[13, 13], layout="NHWC", out_layout="NHWC"
                )
                gv2: R.Tensor((2, 4, 13, 13), dtype="float32") = R.permute_dims(
                    lv2, axes=[0, 3, 1, 2]
                )
                R.output(gv2)
            return gv2

    verify(Input, Expected)
+
+
def test_conv2d_softmax():
    """ConvertLayout: conv2d followed by softmax over the channel axis.

    The softmax axis is remapped from 1 (NCHW channel) to 3 (NHWC channel) so
    the op stays in the converted layout; permute_dims restores NCHW afterward.
    """

    @I.ir_module
    class Input:
        @R.function
        def main(
            x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32")
        ) -> R.Tensor(None, "float32", ndim=4):
            with R.dataflow():
                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
                gv2 = R.nn.softmax(gv, axis=1)
                R.output(gv2)
            return gv2

    @I.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32")
        ) -> R.Tensor(None, dtype="float32", ndim=4):
            with R.dataflow():
                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
                gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
                    lv,
                    lv1,
                    strides=[1, 1],
                    padding=[0, 0, 0, 0],
                    dilation=[1, 1],
                    groups=1,
                    data_layout="NHWC",
                    kernel_layout="OHWI",
                    out_layout="NHWC",
                    out_dtype="float32",
                )
                lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.softmax(gv, axis=3)
                gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims(
                    lv2, axes=[0, 3, 1, 2]
                )
                R.output(gv2)
            return gv2

    verify(Input, Expected)
+
+
def test_conv2d_batchnorm():
    """ConvertLayout: conv2d followed by batch_norm (tuple result).

    batch_norm's channel axis is remapped 1 -> 3 for NHWC; Expected unpacks the
    tuple, permutes only the normalized tensor back to NCHW, and repacks the
    three results (the per-channel stats are layout-independent 1-D tensors).
    """

    @I.ir_module
    class Input:
        @R.function
        def main(
            x: R.Tensor((2, 3, 28, 28), "float32"),
            w: R.Tensor((4, 3, 3, 3), "float32"),
            gamma: R.Tensor((4,), dtype="float32"),
            beta: R.Tensor((4,), dtype="float32"),
            moving_mean: R.Tensor((4,), dtype="float32"),
            moving_var: R.Tensor((4,), dtype="float32"),
        ):
            with R.dataflow():
                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
                gv2: R.Tuple(
                    R.Tensor((2, 4, 26, 26), dtype="float32"),
                    R.Tensor((4,), dtype="float32"),
                    R.Tensor((4,), dtype="float32"),
                ) = R.nn.batch_norm(gv, gamma, beta, moving_mean, moving_var, axis=1)
                R.output(gv2)
            return gv2

    @I.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor((2, 3, 28, 28), dtype="float32"),
            w: R.Tensor((4, 3, 3, 3), dtype="float32"),
            gamma: R.Tensor((4,), dtype="float32"),
            beta: R.Tensor((4,), dtype="float32"),
            moving_mean: R.Tensor((4,), dtype="float32"),
            moving_var: R.Tensor((4,), dtype="float32"),
        ):
            with R.dataflow():
                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
                gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
                    lv,
                    lv1,
                    strides=[1, 1],
                    padding=[0, 0, 0, 0],
                    dilation=[1, 1],
                    groups=1,
                    data_layout="NHWC",
                    kernel_layout="OHWI",
                    out_layout="NHWC",
                    out_dtype="float32",
                )
                gv2: R.Tuple(
                    R.Tensor((2, 26, 26, 4), dtype="float32"),
                    R.Tensor((4,), dtype="float32"),
                    R.Tensor((4,), dtype="float32"),
                ) = R.nn.batch_norm(
                    gv,
                    gamma,
                    beta,
                    moving_mean,
                    moving_var,
                    axis=3,
                    epsilon=1.0000000000000001e-05,
                    center=True,
                    scale=True,
                )
                lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = gv2[0]
                lv3: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims(
                    lv2, axes=[0, 3, 1, 2]
                )
                lv4: R.Tensor((4,), dtype="float32") = gv2[1]
                lv5: R.Tensor((4,), dtype="float32") = gv2[2]
                gv3 = (lv3, lv4, lv5)
                R.output(gv3)
            return gv3

    verify(Input, Expected)
+
+
def test_conv2d_layernorm():
    """ConvertLayout: conv2d followed by layer_norm over the spatial axes.

    The normalization axes [-2, -1] (H, W in NCHW) are remapped to [1, 2]
    (H, W in NHWC); the NCHW result is restored by a final permute_dims.
    """

    @I.ir_module
    class Input:
        @R.function
        def main(
            x: R.Tensor((2, 3, 28, 28), "float32"),
            w: R.Tensor((4, 3, 3, 3), "float32"),
            gamma: R.Tensor((26, 26), dtype="float32"),
            beta: R.Tensor((26, 26), dtype="float32"),
        ) -> R.Tensor(None, "float32", ndim=4):
            with R.dataflow():
                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
                gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.nn.layer_norm(
                    gv, gamma, beta, axes=[-2, -1]
                )
                R.output(gv2)
            return gv2

    @I.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor((2, 3, 28, 28), dtype="float32"),
            w: R.Tensor((4, 3, 3, 3), dtype="float32"),
            gamma: R.Tensor((26, 26), dtype="float32"),
            beta: R.Tensor((26, 26), dtype="float32"),
        ) -> R.Tensor(None, dtype="float32", ndim=4):
            with R.dataflow():
                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
                gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
                    lv,
                    lv1,
                    strides=[1, 1],
                    padding=[0, 0, 0, 0],
                    dilation=[1, 1],
                    groups=1,
                    data_layout="NHWC",
                    kernel_layout="OHWI",
                    out_layout="NHWC",
                    out_dtype="float32",
                )
                lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.layer_norm(
                    gv,
                    gamma,
                    beta,
                    axes=[1, 2],
                    epsilon=1.0000000000000001e-05,
                    center=True,
                    scale=True,
                )
                gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims(
                    lv2, axes=[0, 3, 1, 2]
                )
                R.output(gv2)
            return gv2

    verify(Input, Expected)
+
+
def test_conv2d_resize2d():
    """ConvertLayout: conv2d followed by image.resize2d.

    resize2d's layout is rewritten NCHW -> NHWC (Expected spells out all of the
    op's default attributes as the printer emits them); permute_dims restores
    the NCHW output at the end.
    """

    @I.ir_module
    class Input:
        @R.function
        def main(
            x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32")
        ) -> R.Tensor(None, "float32", ndim=4):
            with R.dataflow():
                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
                gv2 = R.image.resize2d(gv, (52, 52), layout="NCHW")
                R.output(gv2)
            return gv2

    @I.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32")
        ) -> R.Tensor(None, dtype="float32", ndim=4):
            with R.dataflow():
                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
                gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
                    lv,
                    lv1,
                    strides=[1, 1],
                    padding=[0, 0, 0, 0],
                    dilation=[1, 1],
                    groups=1,
                    data_layout="NHWC",
                    kernel_layout="OHWI",
                    out_layout="NHWC",
                    out_dtype="float32",
                )
                lv2: R.Tensor((2, 52, 52, 4), dtype="float32") = R.image.resize2d(
                    gv,
                    (52, 52),
                    roi=[T.float32(0), T.float32(0), T.float32(0), T.float32(0)],
                    layout="NHWC",
                    method="linear",
                    coordinate_transformation_mode="half_pixel",
                    rounding_method="round",
                    cubic_alpha=-0.5,
                    cubic_exclude=0,
                    extrapolation_value=0,
                    out_dtype="void",
                )
                gv2: R.Tensor((2, 4, 52, 52), dtype="float32") = R.permute_dims(
                    lv2, axes=[0, 3, 1, 2]
                )
                R.output(gv2)
            return gv2

    verify(Input, Expected)
+
+
def test_conv2d_unknown_bias_dim():
    """ConvertLayout: conv2d output added to a tensor of unknown rank.

    Because w2 has no known ndim, the layout conversion cannot propagate
    through the add; Expected permutes the conv output back to NCHW before the
    addition instead of converting the add itself.
    """

    @I.ir_module
    class Input:
        @R.function
        def main(
            x: R.Tensor((2, 3, 28, 28), "float32"),
            w: R.Tensor((4, 3, 3, 3), "float32"),
            w2: R.Tensor(dtype="float32"),
        ) -> R.Tensor(None, "float32"):
            with R.dataflow():
                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
                gv2 = w2 + gv
                R.output(gv2)
            return gv2

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor((2, 3, 28, 28), dtype="float32"),
            w: R.Tensor((4, 3, 3, 3), dtype="float32"),
            w2: R.Tensor(dtype="float32"),
        ) -> R.Tensor(dtype="float32"):
            with R.dataflow():
                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
                gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
                    lv,
                    lv1,
                    strides=[1, 1],
                    padding=[0, 0, 0, 0],
                    dilation=[1, 1],
                    groups=1,
                    data_layout="NHWC",
                    kernel_layout="OHWI",
                    out_layout="NHWC",
                    out_dtype="float32",
                )
                lv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims(
                    gv, axes=[0, 3, 1, 2]
                )
                gv2: R.Tensor(dtype="float32") = R.add(w2, lv2)
                R.output(gv2)
            return gv2

    verify(Input, Expected)
+
+
def test_binary_broadcast():
    """ConvertLayout: conv2d output added to a 2-D broadcast operand.

    The (26, 26) bias broadcasts against the NCHW shape but not against NHWC,
    so Expected converts the conv output back to NCHW before the add rather
    than propagating the NHWC layout through it.
    """

    @I.ir_module
    class Input:
        @R.function
        def main(
            x: R.Tensor((2, 3, 28, 28), "float32"),
            w: R.Tensor((4, 3, 3, 3), "float32"),
            bias: R.Tensor((26, 26), "float32"),
        ) -> R.Tensor(None, "float32", ndim=4):
            with R.dataflow():
                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
                gv2: R.Tensor((2, 4, 26, 26), "float32") = R.add(gv, bias)
                R.output(gv2)
            return gv2

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor((2, 3, 28, 28), dtype="float32"),
            w: R.Tensor((4, 3, 3, 3), dtype="float32"),
            bias: R.Tensor((26, 26), dtype="float32"),
        ) -> R.Tensor((2, 4, 26, 26), dtype="float32"):
            with R.dataflow():
                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
                gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
                    lv,
                    lv1,
                    strides=[1, 1],
                    padding=[0, 0, 0, 0],
                    dilation=[1, 1],
                    groups=1,
                    data_layout="NHWC",
                    kernel_layout="OHWI",
                    out_layout="NHWC",
                    out_dtype="float32",
                )
                lv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims(
                    gv, axes=[0, 3, 1, 2]
                )
                gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.add(lv2, bias)
                R.output(gv2)
            return gv2

    verify(Input, Expected)
+
+
# Run all tests in this file through TVM's pytest-based test entry point.
if __name__ == "__main__":
    tvm.testing.main()