This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new fec44f605e [Unity][Transform] Automatic Layout Conversion (#14257)
fec44f605e is described below

commit fec44f605edae8d96b03a571cf1a786843cc6c82
Author: Bohan Hou <[email protected]>
AuthorDate: Sun Mar 19 10:55:57 2023 -0400

    [Unity][Transform] Automatic Layout Conversion (#14257)
    
    This PR adds a new pass, ConvertLayout, that converts the layout of conv2d
    operators (NCHW->NHWC) and tries to propagate this conversion when appropriate.
    
    See
    https://github.com/spectrometerHBH/tvm/blob/convert-layout/src/relax/transform/convert_layout.cc#L39
    for how this pass works.
    
    It works at the op level, in parallel with the community's ongoing Layout
    Transformation effort, which works at the TIR level.
---
 include/tvm/relax/nested_msg.h                     |   44 +
 include/tvm/relax/transform.h                      |    9 +
 python/tvm/relax/transform/transform.py            |   17 +
 src/relax/op/image/resize.cc                       |   16 +-
 src/relax/op/nn/convolution.cc                     |   45 +-
 src/relax/op/nn/nn.cc                              |  101 +-
 src/relax/op/nn/pooling.cc                         |   43 +-
 src/relax/op/op_common.cc                          |    8 +
 src/relax/op/op_common.h                           |   22 +-
 src/relax/op/tensor/binary.cc                      |   22 +
 src/relax/op/tensor/binary.h                       |    3 +-
 src/relax/op/tensor/datatype.cc                    |    3 +-
 src/relax/op/tensor/index.cc                       |   23 +-
 src/relax/op/tensor/manipulate.cc                  |  189 ++-
 src/relax/op/tensor/statistical.cc                 |   52 +
 src/relax/op/tensor/statistical.h                  |   27 +-
 src/relax/op/tensor/ternary.cc                     |   18 +-
 src/relax/transform/convert_layout.cc              |  309 +++++
 src/relax/transform/infer_layout_utils.cc          |  126 ++
 src/relax/transform/infer_layout_utils.h           |  244 ++++
 src/relax/transform/utils.cc                       |   32 +
 src/relax/transform/utils.h                        |   42 +
 .../python/relax/test_transform_convert_layout.py  | 1352 ++++++++++++++++++++
 23 files changed, 2711 insertions(+), 36 deletions(-)

diff --git a/include/tvm/relax/nested_msg.h b/include/tvm/relax/nested_msg.h
index 93fc9a36c5..0564c26687 100644
--- a/include/tvm/relax/nested_msg.h
+++ b/include/tvm/relax/nested_msg.h
@@ -531,6 +531,50 @@ Expr TransformTupleLeaf(Expr expr, 
std::array<NestedMsg<T>, N> msgs, FType ftran
   }
 }
 
+/*!
+ * \brief Recursively transform the tuple structure in sinfo and msgs along with it.
+ *
+ * This function will call ftransleaf for each leaf sinfo in sinfo.
+ * This function will throw an error if the nesting structure in msgs does not
+ * match the tuple nesting structure in sinfo.
+ *
+ * \param sinfo The input sinfo to be transformed.
+ * \param msgs The input messages to guide the transformation.
+ * \param ftransleaf The leaf transformation function, with signature
+ *        ftransleaf(StructInfo, std::array<NestedMsg<T>, N>)->StructInfo
+ * \tparam T the content type of nested msg
+ * \tparam N the number of messages
+ * \tparam FType The visit function type.
+ */
+template <typename T, std::size_t N, typename FType>
+StructInfo TransformTupleLeaf(StructInfo sinfo, std::array<NestedMsg<T>, N> 
msgs,
+                              FType ftransleaf) {
+  if (const auto* tuple = sinfo.as<TupleStructInfoNode>()) {
+    std::array<Array<NestedMsg<T>>, N> msg_arrays;
+    for (size_t i = 0; i < N; ++i) {
+      ICHECK(msgs[i].IsNested()) << "Expected nested to match tuple";
+      msg_arrays[i] = msgs[i].NestedArray();
+    }
+    bool same = true;
+    Array<StructInfo> fields;
+    fields.reserve(tuple->fields.size());
+    for (size_t i = 0; i < tuple->fields.size(); ++i) {
+      StructInfo field = tuple->fields[i];
+      std::array<NestedMsg<T>, N> sub_msgs;
+      for (size_t j = 0; j < N; ++j) {
+        sub_msgs[j] = msg_arrays[j][i];
+      }
+      fields.push_back(TransformTupleLeaf(field, std::move(sub_msgs), 
ftransleaf));
+      same &= (fields.back().same_as(field));
+    }
+    return same ? sinfo : TupleStructInfo(fields);
+  } else {
+    for (const auto& msg : msgs) {
+      ICHECK(msg.IsLeaf()) << "Expected leaf to match non-tuple";
+    }
+    return ftransleaf(sinfo, msgs);
+  }
+}
+
 }  // namespace relax
 }  // namespace tvm
 #endif  // TVM_RELAX_NESTED_MSG_H_
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index e0fe226e83..ead8b0c31e 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -369,6 +369,7 @@ TVM_DLL Pass RunCodegen(Optional<Map<String, Map<String, 
ObjectRef>>> target_opt
  * \return The Pass.
  */
 TVM_DLL Pass SimplifyNormInference();
+
 /*!
  * \brief Returns a pass which replaces PrimFuncs which have matching 
kOperatorName attribute in \p
  * op_impl_map, with replacement PrimFunc that could possibly have different 
layouts on i/o
@@ -383,6 +384,14 @@ TVM_DLL Pass SimplifyNormInference();
  */
 TVM_DLL Pass AlterOpImpl(const Map<String, tir::PrimFunc>& op_impl_map,
                          const Map<String, Array<tir::IndexMap>>& 
op_buffer_transforms);
+
+/*!
+ * \brief Layout conversion pass.
+ * \param desired_layouts Map from an operator name (e.g. "relax.nn.conv2d") to its
+ *        desired layouts (for conv2d: data, weight and, optionally, output).
+ * \return The Pass.
+ */
+TVM_DLL Pass ConvertLayout(Map<String, Array<String>> desired_layouts);
+
 }  // namespace transform
 }  // namespace relax
 }  // namespace tvm
diff --git a/python/tvm/relax/transform/transform.py 
b/python/tvm/relax/transform/transform.py
index 0df29dc093..c10d0130c1 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -640,6 +640,23 @@ def AlterOpImpl(
     return _ffi_api.AlterOpImpl(op_impl_map, op_buffer_transforms)  # type: 
ignore
 
 
+def ConvertLayout(desired_layouts: Dict[str, List[str]]) -> 
tvm.ir.transform.Pass:
+    """Automatic layout conversion pass.
+
+    Parameters
+    ----------
+    desired_layouts : Dict[str, List[str]]
+        A map from an operator name to the desired layouts of that operator's
+        feature map (data), weight and, optionally, output. If the output layout
+        is omitted, it defaults to the feature map layout. For example, to convert
+        the layout of conv2d from NCHW to NHWC, set the desired layouts of conv2d
+        to {"relax.nn.conv2d": ["NHWC", "OHWI"]}.
+
+    Returns
+    -------
+    ret : tvm.transform.Pass
+        The registered pass for layout conversion.
+    """
+    return _ffi_api.ConvertLayout(desired_layouts)  # type: ignore
+
+
 def _wrap_class_function_pass(pass_cls, pass_info):
     """Wrap a python class as function pass."""
 
diff --git a/src/relax/op/image/resize.cc b/src/relax/op/image/resize.cc
index 2711b3cc45..de6eec6236 100644
--- a/src/relax/op/image/resize.cc
+++ b/src/relax/op/image/resize.cc
@@ -102,12 +102,26 @@ StructInfo InferStructInfoResize2D(const Call& call, 
const BlockBuilder& ctx) {
   return TensorStructInfo(ShapeExpr(out_shape), out_dtype);
 }
 
+InferLayoutOutput InferLayoutResize2d(const Call& call,
+                                      const Map<String, Array<String>>& 
desired_layouts,
+                                      const VarLayoutMap& var_layout_map) {
+  ICHECK(NoDesiredLayout(call, desired_layouts));
+  const auto* attrs = call->attrs.as<Resize2DAttrs>();
+  ICHECK(attrs) << "Invalid Call";
+
+  LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);
+  ObjectPtr<Resize2DAttrs> new_attrs = make_object<Resize2DAttrs>(*attrs);
+  new_attrs->layout = TransposeLike(attrs->layout, InitialLayout(4), 
layout->layout).name();
+  return InferLayoutOutput({layout, InitialNLayout(call->args[1])}, {layout}, 
Attrs(new_attrs));
+}
+
 TVM_REGISTER_OP("relax.image.resize2d")
     .set_attrs_type<Resize2DAttrs>()
     .set_num_inputs(2)
     .add_argument("data", "Tensor", "The input tensor.")
     .add_argument("size", "Shape", "The output image shape.")
-    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoResize2D);
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoResize2D)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutResize2d);
 
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc
index 8dc3c9696f..f356876620 100644
--- a/src/relax/op/nn/convolution.cc
+++ b/src/relax/op/nn/convolution.cc
@@ -135,12 +135,55 @@ StructInfo InferStructInfoConv2d(const Call& call, const 
BlockBuilder& ctx) {
   return TensorStructInfo(ShapeExpr(out_shape), out_dtype);
 }
 
+InferLayoutOutput InferLayoutConv2d(const Call& call,
+                                    const Map<String, Array<String>>& 
desired_layouts,
+                                    const VarLayoutMap& var_layout_map) {
+  const auto& it = desired_layouts.find("relax.nn.conv2d");
+  const auto* attrs = call->attrs.as<Conv2DAttrs>();
+  ICHECK(attrs) << "Invalid Call";
+
+  LayoutDecision data_layout, weight_layout, output_layout;
+  ObjectPtr<Conv2DAttrs> new_attrs = make_object<Conv2DAttrs>(*attrs);
+
+  if (it != desired_layouts.end()) {
+    // We have a desired layout for conv2d.
+    Layout desired_data_layout = (*it).second[0];
+    Layout desired_weight_layout = (*it).second[1];
+    Layout desired_output_layout = (*it).second.size() == 3 ? (*it).second[2] 
: (*it).second[0];
+    ICHECK_EQ(desired_data_layout.ndim(), desired_data_layout.ndim_primal()) 
<< "Axis swap only";
+    ICHECK_EQ(desired_weight_layout.ndim(), 
desired_weight_layout.ndim_primal())
+        << "Axis swap only";
+    ICHECK_EQ(desired_output_layout.ndim(), 
desired_output_layout.ndim_primal())
+        << "Axis swap only";
+    data_layout = TransposeLike(InitialLayout(4), attrs->data_layout, 
desired_data_layout);
+    weight_layout = TransposeLike(InitialLayout(4), attrs->kernel_layout, 
desired_weight_layout);
+    output_layout = TransposeLike(InitialLayout(4), attrs->out_layout, 
desired_output_layout);
+    new_attrs->data_layout = (*it).second[0];
+    new_attrs->kernel_layout = (*it).second[1];
+    new_attrs->out_layout = (*it).second.size() == 3 ? (*it).second[2] : 
(*it).second[0];
+  } else {
+    // We don't have a desired layout for conv2d.
+    // We can just propagate the layout from the input.
+    data_layout = GetLayoutDecision(var_layout_map, call->args[0]);
+    weight_layout = GetLayoutDecision(var_layout_map, call->args[1]);
+    output_layout = data_layout;
+    new_attrs->data_layout =
+        TransposeLike(attrs->data_layout, InitialLayout(4), 
data_layout->layout).name();
+    new_attrs->kernel_layout =
+        TransposeLike(attrs->kernel_layout, InitialLayout(4), 
weight_layout->layout).name();
+    new_attrs->out_layout =
+        TransposeLike(attrs->out_layout, InitialLayout(4), 
output_layout->layout).name();
+  }
+  return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, 
Attrs(new_attrs));
+}
+
 TVM_REGISTER_OP("relax.nn.conv2d")
     .set_num_inputs(2)
     .add_argument("data", "Tensor", "The input tensor.")
     .add_argument("weight", "Tensor", "The weight tensor.")
     .set_attrs_type<Conv2DAttrs>()
-    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoConv2d);
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoConv2d)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutConv2d);
 
 /* relax.nn.conv2d_transpose */
 TVM_REGISTER_NODE_TYPE(Conv2DTransposeAttrs);
diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc
index 430d2268ce..6bce51ca50 100644
--- a/src/relax/op/nn/nn.cc
+++ b/src/relax/op/nn/nn.cc
@@ -62,11 +62,25 @@ StructInfo InferStructInfoSoftmax(const Call& call, const 
BlockBuilder& ctx) {
   return data_sinfo;
 }
 
+InferLayoutOutput InferLayoutSoftmax(const Call& call,
+                                     const Map<String, Array<String>>& 
desired_layouts,
+                                     const VarLayoutMap& var_layout_map) {
+  ICHECK(NoDesiredLayout(call, desired_layouts));
+  const auto* attrs = call->attrs.as<SoftmaxAttrs>();
+  ICHECK(attrs) << "Invalid Call";
+
+  LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);
+  ObjectPtr<SoftmaxAttrs> new_attrs = make_object<SoftmaxAttrs>(*attrs);
+  new_attrs->axis = FindAxis(layout->layout, attrs->axis);
+  return InferLayoutOutput({layout}, {layout}, Attrs(new_attrs));
+}
+
 TVM_REGISTER_OP("relax.nn.softmax")
     .set_num_inputs(1)
     .add_argument("data", "Tensor", "The input tensor.")
     .set_attrs_type<SoftmaxAttrs>()
-    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoSoftmax);
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoSoftmax)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutSoftmax);
 
 /* relax.nn.log_softmax */
 Expr log_softmax(Expr data, int axis) {
@@ -188,6 +202,28 @@ StructInfo InferStructInfoBatchNorm(const Call& call, 
const BlockBuilder& ctx) {
   }
 }
 
+InferLayoutOutput InferLayoutBatchNorm(const Call& call,
+                                       const Map<String, Array<String>>& 
desired_layouts,
+                                       const VarLayoutMap& var_layout_map) {
+  ICHECK(NoDesiredLayout(call, desired_layouts));
+  std::vector<NLayout> initial_layouts;
+  for (size_t i = 0; i < 5; ++i) {
+    const auto* tensor_sinfo = 
GetStructInfoAs<TensorStructInfoNode>(call->args[i]);
+    ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
+    ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim";
+    initial_layouts.push_back(InitialLayoutDecision(tensor_sinfo->ndim));
+  }
+  const auto* attrs = call->attrs.as<BatchNormAttrs>();
+  ICHECK(attrs) << "Invalid Call";
+
+  LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);
+  ObjectPtr<BatchNormAttrs> new_attrs = make_object<BatchNormAttrs>(*attrs);
+  new_attrs->axis = FindAxis(layout->layout, attrs->axis);
+  return InferLayoutOutput(
+      {layout, initial_layouts[1], initial_layouts[2], initial_layouts[3], 
initial_layouts[4]},
+      {{layout, initial_layouts[3], initial_layouts[4]}}, Attrs(new_attrs));
+}
+
 TVM_REGISTER_OP("relax.nn.batch_norm")
     .set_attrs_type<BatchNormAttrs>()
     .set_num_inputs(5)
@@ -196,7 +232,8 @@ TVM_REGISTER_OP("relax.nn.batch_norm")
     .add_argument("beta", "Tensor", "The beta offset factor.")
     .add_argument("moving_mean", "Tensor", "Running mean of input.")
     .add_argument("moving_var", "Tensor", "Running variance of input.")
-    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoBatchNorm);
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoBatchNorm)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutBatchNorm);
 
 /* relax.nn.layer_norm */
 TVM_REGISTER_NODE_TYPE(LayerNormAttrs);
@@ -225,13 +262,39 @@ StructInfo InferStructInfoLayerNorm(const Call& call, 
const BlockBuilder& ctx) {
                        : input_sinfo[0];
 }
 
+InferLayoutOutput InferLayoutLayerNorm(const Call& call,
+                                       const Map<String, Array<String>>& 
desired_layouts,
+                                       const VarLayoutMap& var_layout_map) {
+  ICHECK(NoDesiredLayout(call, desired_layouts));
+  std::vector<NLayout> initial_layouts;
+  for (size_t i = 0; i < 3; ++i) {
+    const auto* tensor_sinfo = 
GetStructInfoAs<TensorStructInfoNode>(call->args[i]);
+    ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
+    ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim";
+    initial_layouts.push_back(InitialLayoutDecision(tensor_sinfo->ndim));
+  }
+  const auto* attrs = call->attrs.as<LayerNormAttrs>();
+  ICHECK(attrs) << "Invalid Call";
+
+  LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);
+  ObjectPtr<LayerNormAttrs> new_attrs = make_object<LayerNormAttrs>(*attrs);
+  std::vector<Integer> new_axis;
+  for (const auto& axis : attrs->axes) {
+    new_axis.push_back(FindAxis(layout->layout, axis->value));
+  }
+  new_attrs->axes = std::move(new_axis);
+  return InferLayoutOutput({layout, initial_layouts[1], initial_layouts[2]}, 
{layout},
+                           Attrs(new_attrs));
+}
+
 TVM_REGISTER_OP("relax.nn.layer_norm")
     .set_attrs_type<LayerNormAttrs>()
     .set_num_inputs(3)
     .add_argument("data", "Tensor", "Input to which batch_norm will be 
applied.")
     .add_argument("gamma", "Tensor", "The gamma scale factor.")
     .add_argument("beta", "Tensor", "The beta offset factor.")
-    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoLayerNorm);
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoLayerNorm)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutLayerNorm);
 
 /* relax.nn.group_norm */
 TVM_REGISTER_NODE_TYPE(GroupNormAttrs);
@@ -308,13 +371,40 @@ StructInfo InferStructInfoGroupNorm(const Call& call, 
const BlockBuilder& ctx) {
   return data_sinfo;
 }
 
+InferLayoutOutput InferLayoutGroupNorm(const Call& call,
+                                       const Map<String, Array<String>>& 
desired_layouts,
+                                       const VarLayoutMap& var_layout_map) {
+  ICHECK(NoDesiredLayout(call, desired_layouts));
+  std::vector<NLayout> initial_layouts;
+  for (size_t i = 0; i < 3; ++i) {
+    const auto* tensor_sinfo = 
GetStructInfoAs<TensorStructInfoNode>(call->args[i]);
+    ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
+    ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim";
+    initial_layouts.push_back(InitialLayoutDecision(tensor_sinfo->ndim));
+  }
+  const auto* attrs = call->attrs.as<GroupNormAttrs>();
+  ICHECK(attrs) << "Invalid Call";
+
+  LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);
+  ObjectPtr<GroupNormAttrs> new_attrs = make_object<GroupNormAttrs>(*attrs);
+  std::vector<Integer> new_axes;
+  for (const auto& axis : attrs->axes) {
+    new_axes.push_back(FindAxis(layout->layout, axis->value));
+  }
+  new_attrs->axes = std::move(new_axes);
+  new_attrs->channel_axis = FindAxis(layout->layout, attrs->channel_axis);
+  return InferLayoutOutput({layout, initial_layouts[1], initial_layouts[2]}, 
{layout},
+                           Attrs(new_attrs));
+}
+
 TVM_REGISTER_OP("relax.nn.group_norm")
     .set_attrs_type<GroupNormAttrs>()
     .set_num_inputs(3)
     .add_argument("data", "Tensor", "Input to which batch_norm will be 
applied.")
     .add_argument("gamma", "Tensor", "The gamma scale factor.")
     .add_argument("beta", "Tensor", "The beta offset factor.")
-    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoGroupNorm);
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoGroupNorm)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutGroupNorm);
 
 /* relax.nn.dropout */
 TVM_REGISTER_NODE_TYPE(DropoutAttrs);
@@ -338,7 +428,8 @@ TVM_REGISTER_OP("relax.nn.dropout")
     .set_attrs_type<DropoutAttrs>()
     .set_num_inputs(1)
     .add_argument("data", "Tensor", "Input to which dropout will be applied.")
-    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoDropout);
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoDropout)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutUnaryEwise);
 
 /* relax.nn.cross_entropy_with_logits */
 StructInfo InferStructInfoCrossEntropy(const Call& call, const BlockBuilder& 
ctx) {
diff --git a/src/relax/op/nn/pooling.cc b/src/relax/op/nn/pooling.cc
index 61001ce678..be0a794dee 100644
--- a/src/relax/op/nn/pooling.cc
+++ b/src/relax/op/nn/pooling.cc
@@ -117,11 +117,29 @@ StructInfo InferStructInfoPool2D(const Call& call, const 
BlockBuilder& ctx) {
   return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype);
 }
 
+InferLayoutOutput InferLayoutPool2d(const Call& call,
+                                    const Map<String, Array<String>>& 
desired_layouts,
+                                    const VarLayoutMap& var_layout_map) {
+  ICHECK(NoDesiredLayout(call, desired_layouts));
+  const auto* tensor_sinfo = GetStructInfoAs<TensorStructInfoNode>(call);
+  ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
+  ICHECK_EQ(tensor_sinfo->ndim, 4) << "Unsupported initial layout";
+  const auto* attrs = call->attrs.as<Pool2DAttrs>();
+  ICHECK(attrs) << "Invalid Call";
+
+  LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);
+  ObjectPtr<Pool2DAttrs> new_attrs = make_object<Pool2DAttrs>(*attrs);
+  new_attrs->layout = TransposeLike(attrs->layout, InitialLayout(4), 
layout->layout).name();
+  new_attrs->out_layout = TransposeLike(attrs->out_layout, InitialLayout(4), 
layout->layout).name();
+  return InferLayoutOutput({layout}, {layout}, Attrs(new_attrs));
+}
+
 TVM_REGISTER_OP("relax.nn.max_pool2d")
     .set_num_inputs(1)
     .add_argument("data", "Tensor", "The input tensor")
     .set_attrs_type<Pool2DAttrs>()
-    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoPool2D);
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoPool2D)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutPool2d);
 
 Expr avg_pool2d(Expr data, Array<IntImm> pool_size, Array<IntImm> strides, 
Array<IntImm> padding,
                 Array<IntImm> dilation, bool ceil_mode, String layout,
@@ -136,7 +154,8 @@ TVM_REGISTER_OP("relax.nn.avg_pool2d")
     .set_num_inputs(1)
     .add_argument("data", "Tensor", "The input tensor")
     .set_attrs_type<Pool2DAttrs>()
-    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoPool2D);
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoPool2D)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutPool2d);
 
 /* relax.nn.adaptive_avg_pool2d */
 TVM_REGISTER_NODE_TYPE(AdaptivePool2DAttrs);
@@ -196,11 +215,29 @@ StructInfo InferStructInfoAdaptiveAvgPool2D(const Call& 
call, const BlockBuilder
   return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype);
 }
 
+InferLayoutOutput InferLayoutAdaptiveAvgPool2D(const Call& call,
+                                               const Map<String, 
Array<String>>& desired_layouts,
+                                               const VarLayoutMap& 
var_layout_map) {
+  ICHECK(NoDesiredLayout(call, desired_layouts));
+  const auto* tensor_sinfo = GetStructInfoAs<TensorStructInfoNode>(call);
+  ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
+  ICHECK_EQ(tensor_sinfo->ndim, 4) << "Unsupported initial layout";
+  const auto* attrs = call->attrs.as<AdaptivePool2DAttrs>();
+  ICHECK(attrs) << "Invalid Call";
+
+  LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);
+  ObjectPtr<AdaptivePool2DAttrs> new_attrs = 
make_object<AdaptivePool2DAttrs>(*attrs);
+  new_attrs->layout = TransposeLike(attrs->layout, InitialLayout(4), 
layout->layout).name();
+  new_attrs->out_layout = TransposeLike(attrs->out_layout, InitialLayout(4), 
layout->layout).name();
+  return InferLayoutOutput({layout}, {layout}, Attrs(new_attrs));
+}
+
 TVM_REGISTER_OP("relax.nn.adaptive_avg_pool2d")
     .set_attrs_type<AdaptivePool2DAttrs>()
     .set_num_inputs(1)
     .add_argument("data", "Tensor", "The input tensor")
-    .set_attr<FInferStructInfo>("FInferStructInfo", 
InferStructInfoAdaptiveAvgPool2D);
+    .set_attr<FInferStructInfo>("FInferStructInfo", 
InferStructInfoAdaptiveAvgPool2D)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", 
InferLayoutAdaptiveAvgPool2D);
 
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/op/op_common.cc b/src/relax/op/op_common.cc
index c82c325d9b..0421c957ca 100644
--- a/src/relax/op/op_common.cc
+++ b/src/relax/op/op_common.cc
@@ -117,5 +117,13 @@ std::vector<int> NormalizeAxes(const Call& call, const 
BlockBuilder& ctx, int nd
   return axes_non_neg;
 }
 
+InferLayoutOutput InferLayoutUnaryEwise(const Call& call,
+                                        const Map<String, Array<String>>& 
desired_layouts,
+                                        const VarLayoutMap& var_layout_map) {
+  ICHECK(NoDesiredLayout(call, desired_layouts));
+  LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);
+  return InferLayoutOutput({layout}, {layout}, Attrs(call->attrs));
+}
+
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h
index 29e02946c6..ece4c4a321 100644
--- a/src/relax/op/op_common.h
+++ b/src/relax/op/op_common.h
@@ -34,6 +34,8 @@
 #include <utility>
 #include <vector>
 
+#include "../transform/infer_layout_utils.h"
+
 namespace tvm {
 namespace relax {
 
@@ -68,10 +70,11 @@ inline TensorStructInfo GetUnaryInputTensorStructInfo(const 
Call& call, const Bl
  * \param OpRegName The name of operator to register. The name passed in will
  * be prepended with a prefix "relax." as the identifier string in the 
operator registry.
  */
-#define RELAX_REGISTER_UNARY_OP(OpRegName) \
-  TVM_REGISTER_OP("relax." OpRegName)      \
-      .set_num_inputs(1)                   \
-      .add_argument("x", "Tensor", "The input tensor.")
+#define RELAX_REGISTER_UNARY_OP(OpRegName)              \
+  TVM_REGISTER_OP("relax." OpRegName)                   \
+      .set_num_inputs(1)                                \
+      .add_argument("x", "Tensor", "The input tensor.") \
+      .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutUnaryEwise)
 
 /*!
  * \brief Quick helper macro to expose a make-function to construct the 
operator.
@@ -151,6 +154,17 @@ StructInfo InferStructInfoUnaryArith(const Call& call, 
const BlockBuilder& ctx)
       call, ctx, [](const TensorStructInfo& input_sinfo) { return 
input_sinfo->dtype; });
 }
 
+/*!
+ * \brief Layout infer util for unary elementwise ops. It simply takes the layout of the input.
+ * \param call The context Call to the operator.
+ * \param desired_layouts The desired layouts of certain ops.
+ * \param var_layout_map The layout of vars.
+ * \return The inferred layout result.
+ */
+InferLayoutOutput InferLayoutUnaryEwise(const Call& call,
+                                        const Map<String, Array<String>>& 
desired_layouts,
+                                        const VarLayoutMap& var_layout_map);
+
 /*!
  * \brief Infer the output datatype for binary arithmetic operators.
  * \param call The context Call to the operator.
diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc
index 7e8480ee16..30cd748308 100644
--- a/src/relax/op/tensor/binary.cc
+++ b/src/relax/op/tensor/binary.cc
@@ -78,6 +78,28 @@ StructInfo InferStructInfoBroadcastCMP(const Call& call, 
const BlockBuilder& ctx
          const TensorStructInfo& x2_sinfo) { return DataType::Bool(); });
 }
 
+InferLayoutOutput InferLayoutBinaryEwise(const Call& call,
+                                         const Map<String, Array<String>>& 
desired_layouts,
+                                         const VarLayoutMap& var_layout_map) {
+  ICHECK(NoDesiredLayout(call, desired_layouts));
+  LayoutDecision layout1 = GetLayoutDecision(var_layout_map, call->args[0]);
+  LayoutDecision layout2 = GetLayoutDecision(var_layout_map, call->args[1]);
+
+  auto* x1_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  auto* x2_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[1]);
+
+  ICHECK(!x1_sinfo->IsUnknownNdim() && !x2_sinfo->IsUnknownNdim())
+      << "Unknown dim tensors should not be handled by this function";
+
+  if (x1_sinfo->ndim <= x2_sinfo->ndim) {
+    LayoutDecision out_layout = FollowDecision(layout1, x2_sinfo->ndim);
+    return InferLayoutOutput({layout1, out_layout}, {out_layout}, 
Attrs(call->attrs));
+  } else {
+    LayoutDecision out_layout = FollowDecision(layout2, x1_sinfo->ndim);
+    return InferLayoutOutput({out_layout, layout2}, {out_layout}, 
Attrs(call->attrs));
+  }
+}
+
 /***************** Arithmetic operators *****************/
 
 RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(add);
diff --git a/src/relax/op/tensor/binary.h b/src/relax/op/tensor/binary.h
index 0a48e727e6..197110c000 100644
--- a/src/relax/op/tensor/binary.h
+++ b/src/relax/op/tensor/binary.h
@@ -46,7 +46,8 @@ namespace relax {
   TVM_REGISTER_OP("relax." #OpName)                                \
       .set_num_inputs(2)                                           \
       .add_argument("x1", "Tensor", "The first input tensor.")     \
-      .add_argument("x2", "Tensor", "The second input tensor.")
+      .add_argument("x2", "Tensor", "The second input tensor.")    \
+      .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutBinaryEwise)
 
 #define RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(OpName)             \
   RELAX_REGISTER_BINARY_OP_AND_IMPL(OpName).set_attr<FInferStructInfo>( \
diff --git a/src/relax/op/tensor/datatype.cc b/src/relax/op/tensor/datatype.cc
index 0c647aa866..349a54ee4d 100644
--- a/src/relax/op/tensor/datatype.cc
+++ b/src/relax/op/tensor/datatype.cc
@@ -54,7 +54,8 @@ TVM_REGISTER_OP("relax.astype")
     .set_attrs_type<AstypeAttrs>()
     .set_num_inputs(1)
     .add_argument("x", "Tensor", "The input tensor")
-    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoAstype);
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoAstype)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutUnaryEwise);
 
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc
index 246abef908..29f668ccf3 100644
--- a/src/relax/op/tensor/index.cc
+++ b/src/relax/op/tensor/index.cc
@@ -185,11 +185,32 @@ StructInfo InferStructInfoStridedSlice(const Call& call, 
const BlockBuilder& ctx
   return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype);
 }
 
+InferLayoutOutput InferLayoutStridedSlice(const Call& call,
+                                          const Map<String, Array<String>>& 
desired_layouts,
+                                          const VarLayoutMap& var_layout_map) {
+  ICHECK(NoDesiredLayout(call, desired_layouts));
+
+  const auto* attrs = call->attrs.as<StridedSliceAttrs>();
+  ICHECK(attrs != nullptr) << "Invalid Call";
+  const auto* tensor_sinfo = 
GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
+  ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim";
+  LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, 
call->args[0]);
+  std::vector<Integer> new_axes;
+  for (const auto& axis : attrs->axes) {
+    new_axes.push_back(FindAxis(existing_layout->layout, axis->value));
+  }
+  ObjectPtr<StridedSliceAttrs> new_attrs = 
make_object<StridedSliceAttrs>(*attrs);
+  new_attrs->axes = std::move(new_axes);
+  return InferLayoutOutput({existing_layout}, {existing_layout}, 
Attrs(new_attrs));
+}
+
 TVM_REGISTER_OP("relax.strided_slice")
     .set_attrs_type<StridedSliceAttrs>()
     .set_num_inputs(1)
     .add_argument("x", "Tensor", "The source tensor to be sliced.")
-    .set_attr<FInferStructInfo>("FInferStructInfo", 
InferStructInfoStridedSlice);
+    .set_attr<FInferStructInfo>("FInferStructInfo", 
InferStructInfoStridedSlice)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutStridedSlice);
 
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/op/tensor/manipulate.cc 
b/src/relax/op/tensor/manipulate.cc
index c7bf051302..49f745608f 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -26,6 +26,7 @@
 
 #include <algorithm>
 #include <numeric>
+#include <string>
 #include <utility>
 #include <vector>
 
@@ -272,11 +273,35 @@ StructInfo InferStructInfoConcat(const Call& call, const 
BlockBuilder& ctx) {
   }
 }
 
+InferLayoutOutput InferLayoutConcat(const Call& call,
+                                    const Map<String, Array<String>>& 
desired_layouts,
+                                    const VarLayoutMap& var_layout_map) {
+  ICHECK(NoDesiredLayout(call, desired_layouts));
+
+  const auto* attrs = call->attrs.as<ConcatAttrs>();
+  ICHECK(attrs != nullptr) << "Invalid Call";
+  NLayout nlayout = GetNLayout(var_layout_map, call->args[0]);
+  ICHECK(nlayout.IsNested());
+  ICHECK(nlayout.NestedArray()[0].IsLeaf());
+
+  int n_tensor = nlayout.NestedArray().size();
+  LayoutDecision layout = nlayout.NestedArray()[0].LeafValue();
+  Array<NLayout> input_layouts, output_layouts;
+  for (int i = 0; i < n_tensor; ++i) {
+    input_layouts.push_back(layout);
+  }
+  output_layouts.push_back(layout);
+  ObjectPtr<ConcatAttrs> new_attrs = make_object<ConcatAttrs>(*attrs);
+  new_attrs->axis = Integer(FindAxis(layout->layout, 
attrs->axis.value_or(0)->value));
+  return InferLayoutOutput({NLayout(input_layouts)}, output_layouts, 
Attrs(new_attrs));
+}
+
 TVM_REGISTER_OP("relax.concat")
     .set_attrs_type<ConcatAttrs>()
     .set_num_inputs(1)
     .add_argument("tensors", "Tuple of Tensors", "The input list of tensors.")
-    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoConcat);
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoConcat)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutConcat);
 
 /* relax.expand_dims */
 TVM_REGISTER_NODE_TYPE(ExpandDimsAttrs);
@@ -330,11 +355,49 @@ StructInfo InferStructInfoExpandDims(const Call& call, 
const BlockBuilder& ctx)
   return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype);
 }
 
+InferLayoutOutput InferLayoutExpandDims(const Call& call,
+                                        const Map<String, Array<String>>& 
desired_layouts,
+                                        const VarLayoutMap& var_layout_map) {
+  ICHECK(NoDesiredLayout(call, desired_layouts));
+  const auto* attrs = call->attrs.as<ExpandDimsAttrs>();
+  ICHECK(attrs != nullptr) << "Invalid Call";
+  const auto* tensor_sinfo = 
GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
+  ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now";
+
+  LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, 
call->args[0]);
+  int ndim = tensor_sinfo->ndim;
+  int n_new_dim = attrs->axis.size();
+  int output_ndim = ndim + n_new_dim;
+  std::vector<bool> is_new_dim(output_ndim, false);
+  for (const auto& axis : attrs->axis) {
+    is_new_dim[(axis->value + output_ndim) % output_ndim] = true;
+  }
+  std::string new_layout;
+  for (int i = 0; i < output_ndim; ++i) {
+    if (!is_new_dim[i]) {
+      new_layout.push_back('A' + i);
+    }
+  }
+  new_layout = TransposeStrLike(new_layout, InitialLayout(ndim), 
existing_layout->layout);
+  std::string output_layout;
+  for (int i = 0, j = 0; i < output_ndim; ++i) {
+    if (is_new_dim[i]) {
+      output_layout.push_back('A' + i);
+    } else {
+      output_layout.push_back(new_layout.at(j++));
+    }
+  }
+  return InferLayoutOutput({existing_layout}, 
{LayoutDecision(Layout(output_layout))},
+                           Attrs(call->attrs));
+}
+
 TVM_REGISTER_OP("relax.expand_dims")
     .set_num_inputs(1)
     .set_attrs_type<ExpandDimsAttrs>()
     .add_argument("x", "Tensor", "The input tensor.")
-    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoExpandDims);
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoExpandDims)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutExpandDims);
 
 // Helper function for flatten and reshape.
 PrimExpr ComputeShapeProduct(const Array<PrimExpr>& shape_values) {
@@ -505,11 +568,49 @@ StructInfo InferStructInfoPermuteDims(const Call& call, 
const BlockBuilder& ctx)
   return TensorStructInfo(ShapeExpr(new_shape), data_sinfo->dtype);
 }
 
+InferLayoutOutput InferLayoutPermuteDims(const Call& call,
+                                         const Map<String, Array<String>>& 
desired_layouts,
+                                         const VarLayoutMap& var_layout_map) {
+  ICHECK(NoDesiredLayout(call, desired_layouts));
+
+  const auto* attrs = call->attrs.as<PermuteDimsAttrs>();
+  ICHECK(attrs != nullptr) << "Invalid Call";
+  const auto* tensor_sinfo = 
GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
+  ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now";
+  int ndim = tensor_sinfo->ndim;
+
+  LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, 
call->args[0]);
+  Array<Integer> order;
+  if (attrs->axes.defined()) {
+    order = attrs->axes.value();
+  } else {
+    order.reserve(ndim);
+    for (int i = 0; i < ndim; ++i) {
+      order.push_back(Integer(ndim - i - 1));
+    }
+  }
+  std::string order_str;
+  for (const auto& axis : order) {
+    order_str.push_back(axis->value + 'A');
+  }
+  String new_axes =
+      TransposeStrLike(InitialLayout(ndim).name(), existing_layout->layout, 
order_str);
+  Array<Integer> new_order;
+  for (size_t i = 0; i < new_axes.size(); ++i) {
+    new_order.push_back(Integer(new_axes.at(i) - 'A'));
+  }
+  ObjectPtr<PermuteDimsAttrs> new_attrs = 
make_object<PermuteDimsAttrs>(*attrs);
+  new_attrs->axes = new_order;
+  return InferLayoutOutput({existing_layout}, {InitialLayoutDecision(ndim)}, 
Attrs(new_attrs));
+}
+
 TVM_REGISTER_OP("relax.permute_dims")
     .set_attrs_type<PermuteDimsAttrs>()
     .set_num_inputs(1)
     .add_argument("x", "Tensor", "The input tensor.")
-    .set_attr<FInferStructInfo>("FInferStructInfo", 
InferStructInfoPermuteDims);
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoPermuteDims)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutPermuteDims);
 
 /* relax.reshape */
 Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) {
@@ -758,11 +859,33 @@ StructInfo InferStructInfoSplit(const Call& call, const 
BlockBuilder& ctx) {
   throw;
 }
 
+InferLayoutOutput InferLayoutSplit(const Call& call,
+                                   const Map<String, Array<String>>& 
desired_layouts,
+                                   const VarLayoutMap& var_layout_map) {
+  ICHECK(NoDesiredLayout(call, desired_layouts));
+
+  const auto* attrs = call->attrs.as<SplitAttrs>();
+  ICHECK(attrs != nullptr) << "Invalid Call";
+  const auto* tensor_sinfo = 
GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
+  ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim";
+
+  LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, 
call->args[0]);
+  ObjectPtr<SplitAttrs> new_attrs = make_object<SplitAttrs>(*attrs);
+  new_attrs->axis = FindAxis(existing_layout->layout, attrs->axis);
+  StructInfo out_sinfo = InferStructInfoSplit(call, 
BlockBuilder::Create(IRModule()));
+  const auto* out_tuple = out_sinfo.as<TupleStructInfoNode>();
+  ICHECK(out_tuple != nullptr) << "Invalid Call";
+  NLayout tuple_layouts(Array<NLayout>(out_tuple->fields.size(), 
existing_layout));
+  return InferLayoutOutput({existing_layout}, {tuple_layouts}, 
Attrs(new_attrs));
+}
+
 TVM_REGISTER_OP("relax.split")
     .set_attrs_type<SplitAttrs>()
     .set_num_inputs(1)
     .add_argument("x", "Tensor", "The input tensor.")
-    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoSplit);
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoSplit)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutSplit);
 
 /* relax.squeeze */
 TVM_REGISTER_NODE_TYPE(SqueezeAttrs);
@@ -857,11 +980,67 @@ StructInfo InferStructInfoSqueeze(const Call& call, const 
BlockBuilder& ctx) {
   }
 }
 
+InferLayoutOutput InferLayoutSqueeze(const Call& call,
+                                     const Map<String, Array<String>>& 
desired_layouts,
+                                     const VarLayoutMap& var_layout_map) {
+  ICHECK(NoDesiredLayout(call, desired_layouts));
+
+  const auto* attrs = call->attrs.as<SqueezeAttrs>();
+  ICHECK(attrs != nullptr) << "Invalid Call";
+  const auto* tensor_sinfo = 
GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
+  ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now";
+  ICHECK(tensor_sinfo->shape.defined()) << "Only support static shape for now";
+  int ndim = tensor_sinfo->ndim;
+  const auto* shape = tensor_sinfo->shape.as<ShapeExprNode>();
+  ICHECK(shape != nullptr) << "Only support static shape for now";
+
+  Array<Integer> axis;
+  if (attrs->axis.defined()) {
+    axis = attrs->axis.value();
+  } else {
+    axis.reserve(ndim);
+    for (int i = 0; i < ndim; ++i) {
+      if (tir::is_one(shape->values[i])) {
+        axis.push_back(Integer(i));
+      }
+    }
+  }
+
+  std::string axis_str(ndim, '0');
+  for (const auto& iter : axis) {
+    axis_str[iter->value] = '1';
+  }
+  for (int i = 0, j = 0; i < ndim; ++i) {
+    if (axis_str[i] != '1') {
+      axis_str[i] = 'A' + j++;
+    }
+  }
+
+  LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, 
call->args[0]);
+  String new_axis_str = TransposeStrLike(axis_str, InitialLayout(ndim), 
existing_layout->layout);
+  Array<Integer> new_axis;
+  for (size_t i = 0; i < new_axis_str.size(); ++i) {
+    if (new_axis_str.at(i) == '1') {
+      new_axis.push_back(Integer(i));
+    }
+  }
+  std::string output_layout = new_axis_str;
+  output_layout.erase(std::remove(output_layout.begin(), output_layout.end(), 
'1'),
+                      output_layout.end());
+
+  ObjectPtr<SqueezeAttrs> new_attrs = make_object<SqueezeAttrs>(*attrs);
+  new_attrs->axis = new_axis;
+  return InferLayoutOutput({existing_layout}, 
{LayoutDecision(Layout(output_layout))},
+                           Attrs(new_attrs));
+}
+
 TVM_REGISTER_OP("relax.squeeze")
     .set_num_inputs(1)
     .set_attrs_type<SqueezeAttrs>()
     .add_argument("x", "Tensor", "The input tensor.")
-    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoSqueeze);
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoSqueeze)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutSqueeze);
 
 void CheckCollapseShape(const Call& call, const BlockBuilder& ctx,
                         const Array<PrimExpr>& data_shape, const 
Array<PrimExpr>& target_shape) {
diff --git a/src/relax/op/tensor/statistical.cc 
b/src/relax/op/tensor/statistical.cc
index 41b99fbe36..4de8e3dd63 100644
--- a/src/relax/op/tensor/statistical.cc
+++ b/src/relax/op/tensor/statistical.cc
@@ -24,6 +24,7 @@
 
 #include "statistical.h"
 
+#include <string>
 #include <vector>
 
 namespace tvm {
@@ -82,6 +83,57 @@ StructInfo InferStructInfoStatistical(const Call& call, 
const BlockBuilder& ctx)
   return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype);
 }
 
+InferLayoutOutput InferLayoutStatistical(const Call& call,
+                                         const Map<String, Array<String>>& 
desired_layouts,
+                                         const VarLayoutMap& var_layout_map) {
+  ICHECK(NoDesiredLayout(call, desired_layouts));
+
+  const auto* attrs = call->attrs.as<StatisticalAttrs>();
+  ICHECK(attrs != nullptr) << "Invalid Call";
+  const auto* tensor_sinfo = 
GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
+  ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim";
+  int ndim = tensor_sinfo->ndim;
+
+  Array<Integer> axis;
+  if (attrs->axis.defined()) {
+    axis = attrs->axis.value();
+  } else {
+    axis.reserve(ndim);
+    for (int i = 0; i < ndim; ++i) {
+      axis.push_back(Integer(i));
+    }
+  }
+
+  std::string axis_str(ndim, '0');
+  for (const auto& iter : axis) {
+    axis_str[iter->value] = '1';
+  }
+  for (int i = 0, j = 0; i < ndim; ++i) {
+    if (axis_str[i] != '1') {
+      axis_str[i] = 'A' + j++;
+    }
+  }
+
+  LayoutDecision exisiting_layout = GetLayoutDecision(var_layout_map, 
call->args[0]);
+  String new_axis_str = TransposeStrLike(axis_str, InitialLayout(ndim), 
exisiting_layout->layout);
+  Array<Integer> new_axis;
+  for (size_t i = 0; i < new_axis_str.size(); ++i) {
+    if (new_axis_str.at(i) == '1') {
+      new_axis.push_back(Integer(i));
+    }
+  }
+  std::string output_layout = new_axis_str;
+  output_layout.erase(std::remove(output_layout.begin(), output_layout.end(), 
'1'),
+                      output_layout.end());
+
+  ObjectPtr<StatisticalAttrs> new_attrs = 
make_object<StatisticalAttrs>(*attrs);
+  new_attrs->axis = new_axis;
+  return InferLayoutOutput({exisiting_layout},
+                           {attrs->keepdims ? exisiting_layout : 
Layout(output_layout)},
+                           Attrs(new_attrs));
+}
+
 TVM_REGISTER_NODE_TYPE(StatisticalAttrs);
 
 RELAX_REGISTER_STATISTICAL_OP_INTERFACE(max);
diff --git a/src/relax/op/tensor/statistical.h 
b/src/relax/op/tensor/statistical.h
index 7d322d1129..0adeb82259 100644
--- a/src/relax/op/tensor/statistical.h
+++ b/src/relax/op/tensor/statistical.h
@@ -42,19 +42,20 @@ namespace relax {
  *  1. be prepended with a prefix "relax.op." as the FFI identifier string for 
the make function,
  *  2. be prepended with a prefix "relax." as the identifier string in the 
operator registry.
  */
-#define RELAX_REGISTER_STATISTICAL_OP_INTERFACE(OpName)                  \
-  Expr OpName(Expr x, Optional<Array<Integer>> axis, bool keepdims) {    \
-    ObjectPtr<StatisticalAttrs> attrs = make_object<StatisticalAttrs>(); \
-    attrs->axis = std::move(axis);                                       \
-    attrs->keepdims = keepdims;                                          \
-    static const Op& op = Op::Get("relax." #OpName);                     \
-    return Call(op, {std::move(x)}, Attrs{attrs}, {});                   \
-  }                                                                      \
-  TVM_REGISTER_GLOBAL("relax.op." #OpName).set_body_typed(OpName);       \
-  TVM_REGISTER_OP("relax." #OpName)                                      \
-      .set_num_inputs(1)                                                 \
-      .add_argument("x", "Tensor", "The input data tensor")              \
-      .set_attr<FInferStructInfo>("FInferStructInfo", 
InferStructInfoStatistical)
+#define RELAX_REGISTER_STATISTICAL_OP_INTERFACE(OpName)                        
   \
+  Expr OpName(Expr x, Optional<Array<Integer>> axis, bool keepdims) {          
   \
+    ObjectPtr<StatisticalAttrs> attrs = make_object<StatisticalAttrs>();       
   \
+    attrs->axis = std::move(axis);                                             
   \
+    attrs->keepdims = keepdims;                                                
   \
+    static const Op& op = Op::Get("relax." #OpName);                           
   \
+    return Call(op, {std::move(x)}, Attrs{attrs}, {});                         
   \
+  }                                                                            
   \
+  TVM_REGISTER_GLOBAL("relax.op." #OpName).set_body_typed(OpName);             
   \
+  TVM_REGISTER_OP("relax." #OpName)                                            
   \
+      .set_num_inputs(1)                                                       
   \
+      .add_argument("x", "Tensor", "The input data tensor")                    
   \
+      .set_attr<FInferStructInfo>("FInferStructInfo", 
InferStructInfoStatistical) \
+      .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutStatistical)
 
 /*!
  * \brief Computes the maximum value of tensor elements over given axes.
diff --git a/src/relax/op/tensor/ternary.cc b/src/relax/op/tensor/ternary.cc
index 8820c07afd..93652f43ef 100644
--- a/src/relax/op/tensor/ternary.cc
+++ b/src/relax/op/tensor/ternary.cc
@@ -90,12 +90,28 @@ StructInfo InferStructInfoEwiseFMA(const Call& call, const 
BlockBuilder& ctx) {
   return TensorStructInfo(output_dtype, ndim);
 }
 
+InferLayoutOutput InferLayoutEwiseFMA(const Call& call,
+                                      const Map<String, Array<String>>& 
desired_layouts,
+                                      const VarLayoutMap& var_layout_map) {
+  ICHECK(NoDesiredLayout(call, desired_layouts));
+
+  LayoutDecision layout0 = GetLayoutDecision(var_layout_map, call->args[0]);
+  LayoutDecision layout1 = GetLayoutDecision(var_layout_map, call->args[1]);
+  LayoutDecision layout2 = GetLayoutDecision(var_layout_map, call->args[2]);
+  LayoutDecision layout = layout0;
+  if (NLayoutEqual()(layout1, layout2)) {
+    layout = layout1;
+  }
+  return InferLayoutOutput({layout, layout, layout}, {layout}, 
Attrs(call->attrs));
+}
+
 TVM_REGISTER_OP("relax.ewise_fma")
     .set_num_inputs(3)
     .add_argument("x1", "Tensor", "The left hand operand of the 
multiplication")
     .add_argument("x2", "Tensor", "The right hand operand of the 
multiplication")
     .add_argument("x3", "Tensor", "The operand of the addition")
-    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoEwiseFMA);
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoEwiseFMA)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutEwiseFMA);
 
 Expr ewise_fma(Expr x1, Expr x2, Expr x3) {
   static const Op& op = Op::Get("relax.ewise_fma");
diff --git a/src/relax/transform/convert_layout.cc 
b/src/relax/transform/convert_layout.cc
new file mode 100644
index 0000000000..4f36cfbc0f
--- /dev/null
+++ b/src/relax/transform/convert_layout.cc
@@ -0,0 +1,309 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/relax/transform/convert_layout.cc
+ * \brief Automatic layout conversion pass, especially for axis swapping.
+ */
+
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/nested_msg.h>
+#include <tvm/relax/op_attr_types.h>
+#include <tvm/relax/transform.h>
+
+#include "../op/tensor/manipulate.h"
+#include "infer_layout_utils.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relax {
+
+using tir::Layout;
+
+/*!
+ * \brief Main logic to convert the layout of conv2d. Other ops
+ * can adapt to such layout conversion following conv2d accordingly.
+ *
+ * Structurally speaking, a Relax function is composed of a series of 
VarBinding and
+ * MatchCast. And a specific class of VarBindings is the basic unit we want to 
rewrite.
+ * Formally, they are of the form:
+ *
+ * var = Call(Op, [args], attrs)
+ *
+ * where Op is a specific op we want to rewrite, and attrs is the attributes 
of the op.
+ * var and args are all exprs with type Tensor or Tuple of Tensors. They might
+ * be vars, constants, or Tuple of vars and constants.
+ *
+ * We register the layout inference function for each op (FRelaxInferLayout), 
which accepts the
+ * current call, the desired layout of conv2d ops, and the layout map of 
previous vars. The result
+ * of the layout inference function is contained in an InferLayoutOutput 
object, which contains 3
+ * fields: input_layouts, output_layouts, and attr, which represents the 
expected input layout,
+ * output_layout and converted attrs of the new op call.
+ *
+ * The rewrite pass does the rewriting in a single forward pass, where for 
each Call(Op),
+ * we collect the current Layout of each input var, and let the InferLayout 
function to infer the
+ * desired layout of the output. The rewriter will use these info to convert
+ * the layout of inputs and attrs of the op call, and note down the new layout 
of the output.
+ *
+ * The desired layout of conv2d ops is a map from the name of the op to the 
desired layout of the
+ * desired feature map, weight and output. For example, if we want to convert 
the layout of conv2d
+ * from NCHW to NHWC, we can set the desired layout of conv2d to be {"conv2d": 
["NHWC", "OHWI"]}.
+ *
+ * The way we represent the layout of a var is a NLayout object, which is a 
nested tuple of Layout.
+ * The incoming layout of the module will be set as the default layout (We use 
ABCD... as the
+ * default) Note that for operators like conv, pool, people typically use NHWC 
to refer to the axes.
+ * But to be generic and support more operators, we use ABCD... to refer to 
the axes.
+ *
+ * Note that currently the layout conversion of conv2d only support axis 
swapping, such as NCHW to
+ * NWHC. Packed layout such as NCHW to NCHW4c is not supported now.
+ */
+class LayoutConvertMutator : public ExprMutator {
+ public:
+  explicit LayoutConvertMutator(const Map<String, Array<String>>& 
desired_layouts)
+      : desired_layouts_(desired_layouts) {}
+
+ private:
+  Array<Integer> LayoutToIntegers(const Layout& layout) {
+    Array<Integer> ret;
+    LayoutDecision src = InitialLayoutDecision(layout.ndim());
+    for (size_t i = 0; i < layout.ndim(); ++i) {
+      ret.push_back(Integer(src->layout.IndexOf(layout[i])));
+    }
+    return ret;
+  }
+
+  Expr RewriteExpr(const Expr& expr, const NLayout& to) {
+    auto fvisitleaf = [&](const Expr& expr, std::array<NLayout, 2> layouts) -> 
Expr {
+      NLayout from = layouts[0], to = layouts[1];
+      if (NLayoutEqual()(from, to)) return expr;
+      // If not both from and to are unknown, then none of them can be unknown.
+      ICHECK(!NLayoutEqual()(from, LayoutDecision::InitUnknownDim()) &&
+             !NLayoutEqual()(to, LayoutDecision::InitUnknownDim()))
+          << "Cannot convert when exactly one of the layouts is unknown";
+      const auto* tensor = GetStructInfoAs<TensorStructInfoNode>(expr);
+      ICHECK(tensor != nullptr) << "Expect a tensor, but got: " << expr;
+      Layout axes = TransposeLike(InitialLayoutDecision(tensor->ndim)->layout,
+                                  from.LeafValue()->layout, 
to.LeafValue()->layout);
+      return permute_dims(expr, LayoutToIntegers(axes));
+    };
+    return TransformTupleLeaf<LayoutDecision>(
+        VarReplacer::Replace(expr, var_remap_),
+        std::array<NLayout, 2>({GetNLayout(var_layout_map_, expr), to}), 
fvisitleaf);
+  }
+
+  Array<Expr> RewriteArgs(const Array<Expr>& args, const Array<NLayout>& to) {
+    ICHECK(args.size() == to.size());
+    std::vector<Expr> new_args;
+    for (size_t i = 0; i < args.size(); ++i) {
+      new_args.push_back(RewriteExpr(args[i], to[i]));
+    }
+    return std::move(new_args);
+  }
+
+  void VisitBinding(const Binding& binding) final {
+    // Emit the binding
+    ExprMutator::VisitBinding(binding);
+    // The layout is default to be initial if not rewritten.
+    if (var_layout_map_.find(binding->var) == var_layout_map_.end()) {
+      var_layout_map_[binding->var] = InitialNLayout(binding->var);
+    }
+  }
+
+  Expr VisitVars_(const Var& var) {
+    // We encounter a var use outside of inferrable regions, we rewrite it to 
initial layout.
+    return RewriteExpr(var, InitialNLayout(var));
+  }
+
+  Expr VisitExpr_(const VarNode* op) final { return 
VisitVars_(GetRef<Var>(op)); }
+
+  Expr VisitExpr_(const DataflowVarNode* op) final { return 
VisitVars_(GetRef<Var>(op)); }
+
+  bool HasUnknownDimTensor(const NLayout& nlayout) {
+    bool find = false;
+    auto fvisit = [&](const LayoutDecision& layout) {
+      find = find | (NLayoutEqual()(layout, LayoutDecision::InitUnknownDim()));
+    };
+    ForEachLeaf<LayoutDecision>(nlayout, fvisit);
+    return find;
+  }
+
+  bool HasUnknownDimTensor(const Array<Expr>& args) {
+    for (const auto& arg : args) {
+      if (IsNestedTensor(arg)) {
+        if (HasUnknownDimTensor(GetNLayout(var_layout_map_, arg))) {
+          return true;
+        }
+      }
+    }
+    return false;
+  }
+
+  Optional<InferLayoutOutput> GetInferLayoutInfo(const CallNode* call_node,
+                                                 const Map<String, 
Array<String>>& desired_layouts,
+                                                 const VarLayoutMap& 
var_layout_map) {
+    const OpNode* op_node = call_node->op.as<OpNode>();
+    if (op_node == nullptr) return NullOpt;
+    Op op = Downcast<Op>(GetRef<Op>(op_node));
+    const auto attr_map = 
Op::GetAttrMap<FRelaxInferLayout>("FRelaxInferLayout");
+    if (attr_map.count(op) && !HasUnknownDimTensor(call_node->args)) {
+      // If the op has FRelaxInferLayout, and all the input tensors have known 
ndim
+      FRelaxInferLayout f = attr_map[op];
+      return f(GetRef<Call>(call_node), desired_layouts, var_layout_map);
+    } else {
+      // Otherwise, we use the default policy.
+      return NullOpt;
+    }
+  }
+
+  void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) 
final {
+    Optional<InferLayoutOutput> res =
+        GetInferLayoutInfo(call_node, desired_layouts_, var_layout_map_);
+    ObjectPtr<CallNode> new_call = make_object<CallNode>(*call_node);
+    new_call->struct_info_ = NullOpt;
+    if (!res.defined() ||
+        (!IsNestedTensor(binding->var) && 
!binding->var->IsInstance<DataflowVarNode>())) {
+      // Default policy: use the initial layout.
+      // When we don't have the infer layout info, or it's a non-tensor global 
var binding.
+      std::vector<NLayout> input_layout;
+      for (const auto& arg : call_node->args) {
+        input_layout.push_back(InitialNLayout(arg));
+      }
+      Array<Expr> new_args = RewriteArgs(call_node->args, 
std::move(input_layout));
+      new_call->args = std::move(new_args);
+      ReEmitBinding(binding, builder_->Normalize(Call(new_call)));
+      // update the layout map
+      var_layout_map_[binding->var] = InitialNLayout(binding->var);
+    } else {
+      // Convert the layout according to the inferred layout output.
+      Array<Expr> new_args = RewriteArgs(call_node->args, 
res.value()->input_layouts);
+      new_call->args = std::move(new_args);
+      new_call->attrs = std::move(res.value()->new_attrs);
+      Expr cur_call = builder_->Normalize(Call(new_call));
+      if (binding->var->IsInstance<DataflowVarNode>()) {
+        // Dataflow var, we emit the rewritten call.
+        ReEmitBinding(binding, cur_call);
+        // update the layout map
+        var_layout_map_[binding->var] = res.value()->output_layouts[0];
+      } else {
+        // Global var (tensor), we rewrite it to initial layout
+        ICHECK(IsNestedTensor(binding->var));
+        if (!NLayoutEqual()(res.value()->output_layouts[0], 
InitialNLayout(binding->var))) {
+          Var new_var = builder_->Emit(cur_call);
+          var_layout_map_[new_var] = res.value()->output_layouts[0];
+          cur_call = builder_->Normalize(RewriteExpr(new_var, 
InitialNLayout(binding->var)));
+        }
+        ReEmitBinding(binding, cur_call);
+        // update the layout map
+        var_layout_map_[binding->var] = InitialNLayout(binding->var);
+      }
+    }
+  }
+
+  void VisitBinding_(const VarBindingNode* binding, const TupleNode* val) 
final {
+    std::vector<NLayout> input_layout;
+    for (const auto& field : val->fields) {
+      if (binding->var->IsInstance<DataflowVarNode>()) {
+        // Df var: Use the current realized layout to group the tuple;
+        input_layout.push_back(GetNLayout(var_layout_map_, field));
+      } else {
+        // Global var: Use the initial layout to group the tuple;
+        input_layout.push_back(InitialNLayout(field));
+      }
+    }
+    Array<Expr> new_fields = RewriteArgs(val->fields, std::move(input_layout));
+    if (IsNestedTensor(binding->var)) {
+      ReEmitBinding(binding, builder_->Normalize(Tuple(new_fields)));
+      var_layout_map_[binding->var] = input_layout;
+    }
+  }
+
+  void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* 
val) final {
+    NLayout input_layout = binding->var->IsInstance<DataflowVarNode>()
+                               ? GetNLayout(var_layout_map_, val->tuple)
+                               : InitialNLayout(val->tuple);
+    ReEmitBinding(binding, builder_->Normalize(
+                               TupleGetItem(RewriteExpr(val->tuple, 
input_layout), val->index)));
+    // update the layout map
+    var_layout_map_[binding->var] = input_layout.NestedArray()[val->index];
+  }
+
+  void VisitBinding_(const MatchCastNode* binding) final {
+    if (!binding->var->IsInstance<DataflowVarNode>()) {
+      ExprMutator::VisitBinding_(binding);
+      return;
+    }
+    NLayout from_layout = InitialNLayout(binding->value);
+    NLayout input_layout = GetNLayout(var_layout_map_, binding->value);
+    auto fvisitleaf = [&](const StructInfo& sinfo, std::array<NLayout, 2> 
layouts) -> StructInfo {
+      NLayout from = layouts[0], to = layouts[1];
+      if (NLayoutEqual()(from, to)) return sinfo;
+      // If not both from and to are unknown, then none of them can be unknown.
+      ICHECK(!NLayoutEqual()(from, LayoutDecision::InitUnknownDim()) &&
+             !NLayoutEqual()(to, LayoutDecision::InitUnknownDim()))
+          << "Cannot convert when exactly one of the layouts is unknown";
+      const TensorStructInfoNode* tsinfo = sinfo.as<TensorStructInfoNode>();
+      ICHECK(tsinfo != nullptr) << "We can not set layout for non-tensor 
struct";
+      if (!tsinfo->shape.defined()) return sinfo;
+      const ShapeExprNode* shape = tsinfo->shape.value().as<ShapeExprNode>();
+      if (shape == nullptr) return sinfo;
+      ICHECK_EQ(shape->values.size(), to.LeafValue()->layout.ndim());
+      std::vector<PrimExpr> new_shape;
+      for (size_t i = 0; i < shape->values.size(); ++i) {
+        new_shape.push_back(
+            
shape->values[from.LeafValue()->layout.IndexOf(to.LeafValue()->layout[i])]);
+      }
+      return TensorStructInfo(ShapeExpr(new_shape), tsinfo->dtype, 
tsinfo->span);
+    };
+    StructInfo new_struct_info = TransformTupleLeaf<LayoutDecision>(
+        binding->struct_info, std::array<NLayout, 2>({from_layout, 
input_layout}), fvisitleaf);
+    // re-emit old binding if nothing changes
+    if (new_struct_info.same_as(binding->struct_info)) {
+      builder_->EmitNormalized(GetRef<MatchCast>(binding));
+    } else {
+      Var new_var =
+          builder_->EmitMatchCast(RewriteExpr(binding->value, input_layout), 
new_struct_info);
+      var_layout_map_[binding->var] = input_layout;
+      this->var_remap_[binding->var->vid] = new_var;
+    }
+  }
+
+  std::unordered_map<Var, NLayout, ObjectPtrHash, ObjectPtrEqual> 
var_layout_map_;
+  Map<String, Array<String>> desired_layouts_;
+};  // namespace relax
+
+DataflowBlock ConvertLayoutPass(const DataflowBlock& df_block,
+                                Map<String, Array<String>> desired_layouts) {
+  LayoutConvertMutator mutator(desired_layouts);
+  return Downcast<DataflowBlock>(mutator.VisitBindingBlock(df_block));
+}
+
+namespace transform {
+
+Pass ConvertLayout(Map<String, Array<String>> desired_layouts) {
+  runtime::TypedPackedFunc<DataflowBlock(DataflowBlock, IRModule, 
PassContext)> pass_func =
+      [=](DataflowBlock df_block, IRModule m, PassContext pc) {
+        return Downcast<DataflowBlock>(ConvertLayoutPass(df_block, 
desired_layouts));
+      };
+  return CreateDataflowBlockPass(pass_func, 0, "ConvertLayout", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.ConvertLayout").set_body_typed(ConvertLayout);
+
+}  // namespace transform
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/transform/infer_layout_utils.cc 
b/src/relax/transform/infer_layout_utils.cc
new file mode 100644
index 0000000000..e603fb4a1b
--- /dev/null
+++ b/src/relax/transform/infer_layout_utils.cc
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "infer_layout_utils.h"
+
+#include "utils.h"
+
+namespace tvm {
+namespace relax {
+
+using tir::IterVar;
+using tir::Layout;
+
+/*!
+ * \brief Permute `input` the same way `src` is permuted into `dst`:
+ *        output axis i is input axis src.IndexOf(dst[i]).
+ */
+Layout TransposeLike(const Layout& input, const Layout& src, const Layout& dst) {
+  ICHECK(src.ndim() == dst.ndim() && input.ndim() == src.ndim())
+      << "Layouts must have the same size";
+  std::vector<IterVar> axes;
+  axes.reserve(src.ndim());
+  for (size_t i = 0; i < src.ndim(); ++i) {
+    // IndexOf yields -1 when dst[i] is not an axis of src; fail with a clear message
+    // instead of indexing input->axes with a negative position.
+    int src_index = src.IndexOf(dst[i]);
+    ICHECK_GE(src_index, 0) << "Axis at position " << i << " of layout " << dst.name()
+                            << " does not occur in layout " << src.name();
+    axes.push_back(input->axes[src_index]);
+  }
+  return Layout(axes);
+}
+
+/*!
+ * \brief Permute the characters of `input` the same way `src` is permuted into `dst`:
+ *        output character i is input character src.IndexOf(dst[i]).
+ */
+String TransposeStrLike(const String& input, const Layout& src, const Layout& dst) {
+  ICHECK(src.ndim() == dst.ndim() && input.size() == src.ndim())
+      << "Layouts must have the same size";
+  std::string axes;
+  axes.reserve(src.ndim());
+  for (size_t i = 0; i < src.ndim(); ++i) {
+    // IndexOf yields -1 when dst[i] is not an axis of src; reject it explicitly rather
+    // than passing an out-of-range position to String::at.
+    int src_index = src.IndexOf(dst[i]);
+    ICHECK_GE(src_index, 0) << "Axis at position " << i << " of layout " << dst.name()
+                            << " does not occur in layout " << src.name();
+    axes.push_back(input.at(src_index));
+  }
+  return axes;
+}
+
+/*!
+ * \brief Locate logical axis `axis` (0 -> 'A', 1 -> 'B', ...) in `dst`; negative
+ *        values count from the end.
+ * \return Position of that axis letter in dst.name(), or -1 if it does not occur.
+ */
+int FindAxis(const Layout& dst, int axis) {
+  const int ndim = static_cast<int>(dst.ndim());
+  // The normalization below takes a modulo by ndim, which is undefined for ndim == 0.
+  ICHECK_GT(ndim, 0) << "Cannot find axis " << axis << " in a layout without axes";
+  // Normalize in signed arithmetic: adding an int to the unsigned ndim() wraps around
+  // for axis < -ndim and yields a wrong result.
+  axis = ((axis % ndim) + ndim) % ndim;
+  return static_cast<int>(dst.name().find(static_cast<char>('A' + axis)));
+}
+
+/*!
+ * \brief Canonical layout with `ndim` axes named "A", "B", ... in order.
+ *        ndim == 0 gives a layout with no axes, exactly as InitialLayoutDecision
+ *        already does via the same SubLayout(0, 0) call.
+ */
+Layout InitialLayout(int ndim) {
+  ICHECK(ndim >= 0 && ndim <= 26) << "Only support up to 26 dimensions";
+  return Layout("ABCDEFGHIJKLMNOPQRSTUVWXYZ").SubLayout(0, ndim);
+}
+
+LayoutDecision InitialLayoutDecision(int ndim) {
+  // Known rank: the first `ndim` letters of the alphabet, e.g. 3 -> "ABC".
+  if (ndim != kUnknownNDim) {
+    ICHECK(ndim >= 0 && ndim <= 26) << "Only support up to 26 dimensions";
+    const Layout alphabet("ABCDEFGHIJKLMNOPQRSTUVWXYZ");
+    return LayoutDecision(alphabet.SubLayout(0, ndim));
+  }
+  // Rank not known statically.
+  return LayoutDecision::InitUnknownDim();
+}
+
+NLayout InitialNLayout(const StructInfo& sinfo) {
+  // Tensor leaves get the canonical decision for their ndim; any other leaf is "unknown".
+  return MapToNestedMsg<LayoutDecision>(sinfo, [](const StructInfo& leaf) -> NLayout {
+    const auto* tensor = leaf.as<TensorStructInfoNode>();
+    if (tensor == nullptr) {
+      return LayoutDecision::InitUnknownDim();
+    }
+    return NLayout(InitialLayoutDecision(tensor->ndim));
+  });
+}
+
+NLayout InitialNLayout(const Expr& expr) {
+  // Derive the initial decision from the expression's struct info.
+  const StructInfo sinfo = GetStructInfo(expr);
+  return InitialNLayout(sinfo);
+}
+
+LayoutDecision GetLayoutDecision(const VarLayoutMap& var_layout_map, const Expr& arg) {
+  NLayout nested = GetNLayout(var_layout_map, arg);
+  // Only a non-tuple (leaf) layout can be returned as a single decision.
+  ICHECK(nested.IsLeaf()) << "Cannot get layout for " << arg;
+  return nested.LeafValue();
+}
+
+NLayout GetNLayout(const VarLayoutMap& var_layout_map, const Expr& arg) {
+  auto leaf_layout = [&](const Expr& leaf) -> NLayout {
+    // Constants: canonical decision for the constant's rank.
+    if (const auto* constant = leaf.as<ConstantNode>()) {
+      return InitialLayoutDecision(constant->data.Shape().size());
+    }
+    const auto* var_node = leaf.as<VarNode>();
+    if (var_node == nullptr) {
+      return LayoutDecision::InitUnknownDim();
+    }
+    // Variables: the recorded decision if there is one, otherwise the canonical one.
+    auto found = var_layout_map.find(GetRef<Var>(var_node));
+    return found != var_layout_map.end() ? (*found).second : InitialNLayout(leaf);
+  };
+  return MapToNestedMsg<LayoutDecision>(arg, leaf_layout);
+}
+
+bool NoDesiredLayout(const Call& call, const Map<String, Array<String>>& desired_layouts) {
+  // Only operator calls are looked up; any other callee reports false.
+  if (const auto* op_node = call->op.as<OpNode>()) {
+    return desired_layouts.count(op_node->name) == 0;
+  }
+  return false;
+}
+
+/*!
+ * \brief Decision for a tensor of rank `dst_ndim` that follows `src`.
+ *        Equal ranks reuse `src`; otherwise `src` is treated as broadcast onto the
+ *        trailing axes of the higher-rank tensor.
+ */
+LayoutDecision FollowDecision(const LayoutDecision& src, int dst_ndim) {
+  int src_ndim = src->layout.ndim();
+  if (src_ndim == dst_ndim) {
+    return src;
+  }
+  // Broadcast case: the extra leading axes get the canonical letters 'A', 'B', ..., and
+  // every axis letter of src is shifted past them, e.g. src "ACB" with dst_ndim 4
+  // becomes "A" + "BDC" = "ABDC".
+  ICHECK_LT(src_ndim, dst_ndim) << "Cannot broadcast from " << src_ndim << " to " << dst_ndim;
+  const int shift = dst_ndim - src_ndim;
+  // Hoisted out of the loop. Indexing it by axis position assumes each src axis is a
+  // single letter in the name.
+  const std::string src_name = src->layout.name();
+  std::string layout = InitialLayout(shift).name();
+  layout.reserve(dst_ndim);
+  for (int i = 0; i < src_ndim; ++i) {
+    layout.push_back(src_name[i] + shift);
+  }
+  return LayoutDecision(Layout(layout));
+}
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/transform/infer_layout_utils.h 
b/src/relax/transform/infer_layout_utils.h
new file mode 100644
index 0000000000..2cbbe23ede
--- /dev/null
+++ b/src/relax/transform/infer_layout_utils.h
@@ -0,0 +1,244 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file infer_layout_utils.h
+ * \brief Data structures and utility functions for inferring and converting the
+ *        layouts of Relax operators, e.g. to compute convolution in custom layouts.
+ */
+
+#ifndef TVM_RELAX_TRANSFORM_INFER_LAYOUT_UTILS_H_
+#define TVM_RELAX_TRANSFORM_INFER_LAYOUT_UTILS_H_
+
+#include <tvm/relax/attrs/create.h>
+#include <tvm/relax/attrs/datatype.h>
+#include <tvm/relax/attrs/image.h>
+#include <tvm/relax/attrs/linear_algebra.h>
+#include <tvm/relax/attrs/manipulate.h>
+#include <tvm/relax/attrs/nn.h>
+#include <tvm/relax/attrs/statistical.h>
+#include <tvm/relax/expr.h>
+#include <tvm/relax/nested_msg.h>
+#include <tvm/relax/op_attr_types.h>
+#include <tvm/tir/data_layout.h>
+
+#include <array>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace relax {
+
+using tir::Layout;
+
+/*!
+ * \brief A layout decision node that holds the layout decision of a tensor.
+ */
+class LayoutDecisionNode : public Object {
+ public:
+  /*! \brief The layout decision of the tensor. */
+  Layout layout;
+  /*! \brief Whether the dim of tensor is unknown. */
+  bool is_unknown_dim = false;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("layout", &layout);
+    // Reflect the flag as well, so reflection-based printing/serialization keeps it.
+    v->Visit("is_unknown_dim", &is_unknown_dim);
+  }
+
+  TVM_DECLARE_BASE_OBJECT_INFO(LayoutDecisionNode, Object);
+
+  static constexpr const char* _type_key = "relax.transform.LayoutDecision";
+};
+
+/*! \brief Reference to a LayoutDecisionNode. */
+class LayoutDecision : public ObjectRef {
+ public:
+  // Implicit on purpose (see NOLINT): a Layout converts to a known-dim LayoutDecision.
+  LayoutDecision(Layout layout, bool is_unknown_dim = false) {  // NOLINT(*)
+    ObjectPtr<LayoutDecisionNode> node = make_object<LayoutDecisionNode>();
+    node->is_unknown_dim = is_unknown_dim;
+    node->layout = std::move(layout);
+    data_ = std::move(node);
+  }
+
+  /*! \brief Decision for a tensor whose number of dimensions is unknown. */
+  static LayoutDecision InitUnknownDim() { return LayoutDecision(Layout::Undef(), true); }
+
+  /*! \brief The layout name, or "unknown_dim" for an unknown-dim decision. */
+  std::string name() const {
+    const LayoutDecisionNode* node = operator->();
+    return node->is_unknown_dim ? std::string("unknown_dim") : node->layout.name();
+  }
+
+  TVM_DEFINE_OBJECT_REF_METHODS(LayoutDecision, ObjectRef, LayoutDecisionNode);
+};
+
+/*! \brief A (possibly nested, tuple-shaped) layout decision. */
+using NLayout = NestedMsg<LayoutDecision>;
+
+/*!
+ * \brief The result of a FRelaxInferLayout call.
+ */
+class InferLayoutOutputNode : public Object {
+ public:
+  /*! \brief Inferred layouts of the call's inputs. */
+  Array<NLayout> input_layouts;
+  /*! \brief Inferred layouts of the call's outputs. */
+  Array<NLayout> output_layouts;
+  /*! \brief Updated attributes consistent with the inferred layouts. */
+  Attrs new_attrs;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("input_layouts", &input_layouts);
+    v->Visit("output_layouts", &output_layouts);
+    v->Visit("new_attrs", &new_attrs);
+  }
+
+  static constexpr const char* _type_key = "relax.transform.InferLayoutOutput";
+  TVM_DECLARE_BASE_OBJECT_INFO(InferLayoutOutputNode, Object);
+};
+
+/*! \brief Reference to an InferLayoutOutputNode. */
+class InferLayoutOutput : public ObjectRef {
+ public:
+  explicit InferLayoutOutput(Array<NLayout> input_layouts, Array<NLayout> output_layouts,
+                             Attrs new_attrs) {
+    ObjectPtr<InferLayoutOutputNode> node = make_object<InferLayoutOutputNode>();
+    node->new_attrs = std::move(new_attrs);
+    node->output_layouts = std::move(output_layouts);
+    node->input_layouts = std::move(input_layouts);
+    data_ = std::move(node);
+  }
+
+  TVM_DEFINE_OBJECT_REF_METHODS(InferLayoutOutput, ObjectRef, InferLayoutOutputNode);
+};
+
+/*! \brief Equality of nested layouts that compares leaf decisions by name. */
+struct NLayoutEqual {
+  bool operator()(const NLayout& a, const NLayout& b) const {
+    // Leaves match when both are undefined, or both are defined with the same name.
+    auto same_leaf = [](const LayoutDecision& lhs, const LayoutDecision& rhs) {
+      if (!lhs.defined() || !rhs.defined()) {
+        return lhs.defined() == rhs.defined();
+      }
+      return lhs.name() == rhs.name();
+    };
+    return Equal(a, b, same_leaf);
+  }
+};
+
+/*! \brief Map from a variable to the nested layout decision recorded for it. */
+using VarLayoutMap = Map<Var, NLayout>;
+
+/*!
+ * \brief Layout conversion interface.
+ * \param call The call node.
+ * \param desired_layouts The desired layouts of the operator.
+ * \param var_layout_map The layout of the variables.
+ * \return The inferred input/output layouts and the updated attributes of the call.
+ */
+using FRelaxInferLayout = runtime::TypedPackedFunc<InferLayoutOutput(
+    const Call& call, const Map<String, Array<String>>& desired_layouts,
+    const VarLayoutMap& var_layout_map)>;
+
+/*!
+ * \brief Initialize a layout given the number of dimensions.
+ *        The layout uses the first ndim letters of the alphabet, e.g. ndim = 4 gives "ABCD".
+ * \param ndim The number of dimensions (at most 26).
+ * \return The initialized layout.
+ */
+Layout InitialLayout(int ndim);
+
+/*!
+ * \brief Initialize a layout decision given the number of dimensions.
+ * \param ndim The number of dimensions (at most 26), or kUnknownNDim.
+ * \return The initialized layout decision ("AB..." as in InitialLayout), or the
+ *         unknown-dim decision when ndim is kUnknownNDim.
+ */
+LayoutDecision InitialLayoutDecision(int ndim);
+
+/*!
+ * \brief Initialize a nested layout decision given the struct info.
+ * \param sinfo The struct info to mirror. Tensor leaves get the initial layout
+ *        decision for their ndim; other leaves get the unknown-dim decision.
+ * \return The initialized nested layout decision.
+ */
+NLayout InitialNLayout(const StructInfo& sinfo);
+
+/*!
+ * \brief Initialize a nested layout decision given an expression.
+ * \param expr The expression; its struct info is used.
+ * \return The initialized nested layout decision.
+ */
+NLayout InitialNLayout(const Expr& expr);
+
+/*!
+ * \brief Transpose the input layout the same way the src layout is transposed into
+ *        the dst layout: output axis i is input axis src.IndexOf(dst[i]).
+ * \param input The input layout.
+ * \param src The source layout.
+ * \param dst The destination layout.
+ * \return The transposed input layout.
+ */
+Layout TransposeLike(const Layout& input, const Layout& src, const Layout& 
dst);
+
+/*!
+ * \brief Transpose the characters of the input string the same way the src layout is
+ *        transposed into the dst layout.
+ * \param input The input str; must have one character per axis of src.
+ * \param src The source layout.
+ * \param dst The destination layout.
+ * \return The transposed input str.
+ */
+String TransposeStrLike(const String& input, const Layout& src, const Layout& 
dst);
+
+/*!
+ * \brief Find an axis in the dst layout. 0 represents the first axis ('A'), 1 the
+ *        second axis ('B'), etc.; negative values count from the end.
+ * \param dst The destination layout.
+ * \param axis The axis to be found.
+ * \return The position of the axis letter in dst.name(), or -1 if it does not occur.
+ */
+int FindAxis(const Layout& dst, int axis);
+
+/*!
+ * \brief Get the layout decision of the expr. The expr must be a Tensor
+ *        (i.e. its nested layout must be a leaf).
+ * \param var_layout_map The layout of the variables.
+ * \param arg The expr.
+ * \return The layout decision of the expr.
+ */
+LayoutDecision GetLayoutDecision(const VarLayoutMap& var_layout_map, const 
Expr& arg);
+
+/*!
+ * \brief Get the nested layout decision of the expr. The expr must be a nested Tensor.
+ *        Variables use their entry in var_layout_map (or the initial layout if absent),
+ *        constants use the initial layout for their rank, other leaves are unknown-dim.
+ * \param var_layout_map The layout of the variables.
+ * \param arg The expr.
+ * \return The nested layout decision of the expr.
+ */
+NLayout GetNLayout(const VarLayoutMap& var_layout_map, const Expr& arg);
+
+/*!
+ * \brief Check whether the called op has no entry in desired_layouts.
+ * \param call The call node containing the op.
+ * \param desired_layouts The desired layouts, keyed by operator name.
+ * \return True if the callee is an Op whose name is not in desired_layouts;
+ *         false otherwise, including when the callee is not an Op.
+ */
+bool NoDesiredLayout(const Call& call, const Map<String, Array<String>>& 
desired_layouts);
+
+/*!
+ * \brief Let a tensor with dst_ndim dimensions follow the src layout decision.
+ *        When dst_ndim exceeds the rank of src, src is treated as broadcast onto the
+ *        trailing axes.
+ * \param src The source layout decision.
+ * \param dst_ndim The number of dimensions of the tensor (must be >= the rank of src).
+ * \return The layout decision of the tensor.
+ */
+LayoutDecision FollowDecision(const LayoutDecision& src, int dst_ndim);
+
+}  // namespace relax
+}  // namespace tvm
+
+#endif  // TVM_RELAX_TRANSFORM_INFER_LAYOUT_UTILS_H_
diff --git a/src/relax/transform/utils.cc b/src/relax/transform/utils.cc
new file mode 100644
index 0000000000..9a19115f62
--- /dev/null
+++ b/src/relax/transform/utils.cc
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "utils.h"
+
+namespace tvm {
+namespace relax {
+
+bool IsNestedTensor(const StructInfo& sinfo) {
+  // No extra condition: every TensorStructInfo leaf is accepted.
+  auto any_tensor = [](const TensorStructInfo&) { return true; };
+  return IsNestedTensorConditioned(sinfo, any_tensor);
+}
+
+bool IsNestedTensor(const Expr& expr) {
+  // Judge the expression by its struct info.
+  const StructInfo sinfo = GetStructInfo(expr);
+  return IsNestedTensor(sinfo);
+}
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h
index 463e69d56c..003519cffc 100644
--- a/src/relax/transform/utils.h
+++ b/src/relax/transform/utils.h
@@ -155,6 +155,48 @@ bool IsNestedTensorConditioned(const StructInfo& sinfo, 
FType f_condition) {
   return false;
 }
 
+/*!
+ * \brief Check if the given StructInfo is a nested tensor.
+ *        Equivalent to IsNestedTensorConditioned with a condition that accepts every
+ *        TensorStructInfo.
+ * \param sinfo The StructInfo to be checked.
+ * \return true if the given StructInfo is a nested tensor.
+ */
+bool IsNestedTensor(const StructInfo& sinfo);
+
+/*!
+ * \brief Check if the given expr is a nested tensor, judged by its struct info.
+ * \param expr The expr to be checked.
+ * \return true if the given expr is a nested tensor.
+ */
+bool IsNestedTensor(const Expr& expr);
+
+// TODO(@bohan): implements some postorder function accepts a visitor closure
+/*!
+ * \brief Mutator that substitutes variables using a caller-provided map keyed by Var id.
+ * \note The map is held by const reference and must outlive the replacer.
+ */
+class VarReplacer : public ExprMutator {
+ public:
+  using VarMap = std::unordered_map<Id, Var, ObjectPtrHash, ObjectPtrEqual>;
+
+  explicit VarReplacer(const VarMap& var_remap) : replace_map_(var_remap) {}
+
+  static Expr Replace(const Expr& expr, const VarMap& var_remap) {
+    VarReplacer replacer(var_remap);
+    return replacer(expr);
+  }
+
+ private:
+  Expr VisitExpr_(const VarNode* op) final { return ReplaceOrDefault(op); }
+
+  Expr VisitExpr_(const DataflowVarNode* op) final { return ReplaceOrDefault(op); }
+
+  /*!
+   * \brief Return the caller-provided replacement for `op`, if any. Otherwise defer to
+   *        ExprMutator, so remappings it records in its own var_remap_ while re-emitting
+   *        bindings are still applied instead of being silently ignored.
+   */
+  template <typename NodeType>
+  Expr ReplaceOrDefault(const NodeType* op) {
+    auto it = replace_map_.find(op->vid);
+    if (it != replace_map_.end()) {
+      return it->second;
+    }
+    return ExprMutator::VisitExpr_(op);
+  }
+
+  // Deliberately not named var_remap_, which would hide ExprMutator::var_remap_.
+  const VarMap& replace_map_;
+};
+
 /*!
  * \brief Create a Constant with a scalar
  *
diff --git a/tests/python/relax/test_transform_convert_layout.py 
b/tests/python/relax/test_transform_convert_layout.py
new file mode 100644
index 0000000000..1b65c0812a
--- /dev/null
+++ b/tests/python/relax/test_transform_convert_layout.py
@@ -0,0 +1,1352 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+from tvm.relax.transform import ConvertLayout, Normalize
+from tvm.script.parser import ir as I, relax as R, tir as T
+
+
+def verify(input, expected):
+    mod = ConvertLayout({"relax.nn.conv2d": ["NHWC", "OHWI"]})(input)
+    mod = Normalize()(mod)
+    print(mod.script())
+    tvm.ir.assert_structural_equal(mod, expected)
+
+
+def test_conv2d():
+    """A single conv2d: x and w are permuted with axes [0, 2, 3, 1], conv2d runs with
+    data_layout="NHWC"/kernel_layout="OHWI", and its result is permuted back with
+    axes [0, 3, 1, 2]."""
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), 
"float32")
+        ) -> R.Tensor(None, "float32", ndim=4):
+            with R.dataflow():
+                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, 
out_dtype="float32")
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 
3, 3), dtype="float32")
+        ) -> R.Tensor(None, dtype="float32", ndim=4):
+            with R.dataflow():
+                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = 
R.permute_dims(x, axes=[0, 2, 3, 1])
+                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = 
R.permute_dims(w, axes=[0, 2, 3, 1])
+                lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                    out_layout="NHWC",
+                    out_dtype="float32",
+                )
+                gv: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims(
+                    lv2, axes=[0, 3, 1, 2]
+                )
+                R.output(gv)
+            return gv
+
+    verify(Input, Expected)
+
+
+def test_conv2d_onlydim():
+    """Same rewrite as test_conv2d when only ndim=4 is known, not the static shape."""
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor("float32", ndim=4), w: R.Tensor("float32", ndim=4)
+        ) -> R.Tensor(None, "float32", ndim=4):
+            with R.dataflow():
+                gv: R.Tensor("float32", ndim=4) = R.nn.conv2d(x, w, 
out_dtype="float32")
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor(dtype="float32", ndim=4), w: R.Tensor(dtype="float32", 
ndim=4)
+        ) -> R.Tensor(dtype="float32", ndim=4):
+            with R.dataflow():
+                lv: R.Tensor(dtype="float32", ndim=4) = R.permute_dims(x, 
axes=[0, 2, 3, 1])
+                lv1: R.Tensor(dtype="float32", ndim=4) = R.permute_dims(w, 
axes=[0, 2, 3, 1])
+                lv2: R.Tensor(dtype="float32", ndim=4) = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                    out_layout="NHWC",
+                    out_dtype="float32",
+                )
+                gv: R.Tensor(dtype="float32", ndim=4) = R.permute_dims(lv2, 
axes=[0, 3, 1, 2])
+                R.output(gv)
+            return gv
+
+    verify(Input, Expected)
+
+
+def test_conv2d_symbolic():
+    """x is match_cast to a symbolic (N, C, H, W) shape. The match_cast stays in NCHW,
+    and the permute inserted after it has the symbolic shape (N, H, W, C)."""
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor("float32", ndim=4), w: R.Tensor("float32", ndim=4)
+        ) -> R.Tensor(None, "float32", ndim=4):
+            with R.dataflow():
+                N, C, H, W = T.int64(), T.int64(), T.int64(), T.int64()
+                lv0 = R.match_cast(x, R.Tensor((N, C, H, W), "float32"))
+                gv: R.Tensor("float32", ndim=4) = R.nn.conv2d(lv0, w, 
out_dtype="float32")
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor(dtype="float32", ndim=4), w: R.Tensor(dtype="float32", 
ndim=4)
+        ) -> R.Tensor(dtype="float32", ndim=4):
+            N = T.int64()
+            C = T.int64()
+            H = T.int64()
+            W = T.int64()
+            with R.dataflow():
+                lv0: R.Tensor((N, C, H, W), dtype="float32") = R.match_cast(
+                    x, R.Tensor((N, C, H, W), dtype="float32")
+                )
+                lv: R.Tensor((N, H, W, C), dtype="float32") = 
R.permute_dims(lv0, axes=[0, 2, 3, 1])
+                lv1: R.Tensor(dtype="float32", ndim=4) = R.permute_dims(w, 
axes=[0, 2, 3, 1])
+                lv2: R.Tensor(dtype="float32", ndim=4) = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                    out_layout="NHWC",
+                    out_dtype="float32",
+                )
+                gv: R.Tensor(dtype="float32", ndim=4) = R.permute_dims(lv2, 
axes=[0, 3, 1, 2])
+                R.output(gv)
+            return gv
+
+    verify(Input, Expected)
+
+
+def test_conv2d_matchcast_bias():
+    """A match_cast on the conv2d output is rewritten to the NHWC shape (N, H, W, C).
+    The following add consumes it in NHWC (w is permuted to match), and only the add
+    result is permuted back."""
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor("float32", ndim=4), w: R.Tensor("float32", ndim=4)
+        ) -> R.Tensor(None, "float32", ndim=4):
+            with R.dataflow():
+                lv0: R.Tensor("float32", ndim=4) = R.nn.conv2d(x, w, 
out_dtype="float32")
+                N, C, H, W = T.int64(), T.int64(), T.int64(), T.int64()
+                lv1 = R.match_cast(lv0, R.Tensor((N, C, H, W), "float32"))
+                gv = R.add(lv1, w)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor(dtype="float32", ndim=4), w: R.Tensor(dtype="float32", 
ndim=4)
+        ) -> R.Tensor(dtype="float32", ndim=4):
+            N = T.int64()
+            H = T.int64()
+            W = T.int64()
+            C = T.int64()
+            with R.dataflow():
+                lv: R.Tensor(dtype="float32", ndim=4) = R.permute_dims(x, 
axes=[0, 2, 3, 1])
+                lv1: R.Tensor(dtype="float32", ndim=4) = R.permute_dims(w, 
axes=[0, 2, 3, 1])
+                lv0: R.Tensor(dtype="float32", ndim=4) = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                    out_layout="NHWC",
+                    out_dtype="float32",
+                )
+                lv2: R.Tensor((N, H, W, C), dtype="float32") = R.match_cast(
+                    lv0, R.Tensor((N, H, W, C), dtype="float32")
+                )
+                lv3: R.Tensor(dtype="float32", ndim=4) = R.permute_dims(w, 
axes=[0, 2, 3, 1])
+                lv4: R.Tensor(dtype="float32", ndim=4) = R.add(lv2, lv3)
+                gv: R.Tensor(dtype="float32", ndim=4) = R.permute_dims(lv4, 
axes=[0, 3, 1, 2])
+                R.output(gv)
+            return gv
+
+    verify(Input, Expected)
+
+
+def test_conv2d_relu():
+    """relu after conv2d is kept in NHWC; only relu's result is permuted back."""
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), 
"float32")
+        ) -> R.Tensor(None, "float32", ndim=4):
+            with R.dataflow():
+                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, 
out_dtype="float32")
+                gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv)
+                R.output(gv2)
+            return gv2
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 
3, 3), dtype="float32")
+        ) -> R.Tensor(None, dtype="float32", ndim=4):
+            with R.dataflow():
+                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = 
R.permute_dims(x, axes=[0, 2, 3, 1])
+                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = 
R.permute_dims(w, axes=[0, 2, 3, 1])
+                gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                    out_layout="NHWC",
+                    out_dtype="float32",
+                )
+                lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.relu(gv)
+                gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = 
R.permute_dims(
+                    lv2, axes=[0, 3, 1, 2]
+                )
+                R.output(gv2)
+            return gv2
+
+    verify(Input, Expected)
+
+
+def test_relu_conv2d_relu():
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), 
"float32")
+        ) -> R.Tensor(None, "float32", ndim=4):
+            with R.dataflow():
+                x0: R.Tensor((2, 3, 28, 28), "float32") = R.nn.relu(x)
+                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x0, w, 
out_dtype="float32")
+                gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv)
+                R.output(gv2)
+            return gv2
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 
3, 3), dtype="float32")
+        ) -> R.Tensor(None, dtype="float32", ndim=4):
+            with R.dataflow():
+                x0: R.Tensor((2, 3, 28, 28), dtype="float32") = R.nn.relu(x)
+                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(
+                    x0, axes=[0, 2, 3, 1]
+                )
+                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = 
R.permute_dims(w, axes=[0, 2, 3, 1])
+                gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                    out_layout="NHWC",
+                    out_dtype="float32",
+                )
+                lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.relu(gv)
+                gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = 
R.permute_dims(
+                    lv2, axes=[0, 3, 1, 2]
+                )
+                R.output(gv2)
+            return gv2
+
+    verify(Input, Expected)
+
+
+def test_conv2d_relu_tanh():
+    """A relu -> tanh chain after conv2d stays in NHWC; a single permute back follows tanh."""
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), 
"float32")
+        ) -> R.Tensor(None, "float32", ndim=4):
+            with R.dataflow():
+                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, 
out_dtype="float32")
+                gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv)
+                gv3: R.Tensor((2, 4, 26, 26), "float32") = R.tanh(gv2)
+                R.output(gv3)
+            return gv3
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 
3, 3), dtype="float32")
+        ) -> R.Tensor(None, dtype="float32", ndim=4):
+            with R.dataflow():
+                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = 
R.permute_dims(x, axes=[0, 2, 3, 1])
+                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = 
R.permute_dims(w, axes=[0, 2, 3, 1])
+                gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                    out_layout="NHWC",
+                    out_dtype="float32",
+                )
+                gv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.relu(gv)
+                lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.tanh(gv2)
+                gv3: R.Tensor((2, 4, 26, 26), dtype="float32") = 
R.permute_dims(
+                    lv2, axes=[0, 3, 1, 2]
+                )
+                R.output(gv3)
+            return gv3
+
+    verify(Input, Expected)
+
+
+def test_conv2d_add():
+    """add after conv2d runs in NHWC: the bias operand is permuted with axes
+    [0, 2, 3, 1] and the sum is permuted back with axes [0, 3, 1, 2]."""
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), "float32"),
+            w: R.Tensor((4, 3, 3, 3), "float32"),
+            bias: R.Tensor((2, 4, 26, 26), "float32"),
+        ) -> R.Tensor(None, "float32", ndim=4):
+            with R.dataflow():
+                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, 
out_dtype="float32")
+                gv2: R.Tensor((2, 4, 26, 26), "float32") = R.add(gv, bias)
+                R.output(gv2)
+            return gv2
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"),
+            w: R.Tensor((4, 3, 3, 3), dtype="float32"),
+            bias: R.Tensor((2, 4, 26, 26), dtype="float32"),
+        ) -> R.Tensor(None, dtype="float32", ndim=4):
+            with R.dataflow():
+                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = 
R.permute_dims(x, axes=[0, 2, 3, 1])
+                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = 
R.permute_dims(w, axes=[0, 2, 3, 1])
+                gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                    out_layout="NHWC",
+                    out_dtype="float32",
+                )
+                lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = 
R.permute_dims(
+                    bias, axes=[0, 2, 3, 1]
+                )
+                lv3: R.Tensor((2, 26, 26, 4), dtype="float32") = R.add(gv, lv2)
+                gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = 
R.permute_dims(
+                    lv3, axes=[0, 3, 1, 2]
+                )
+                R.output(gv2)
+            return gv2
+
+    verify(Input, Expected)
+
+
+def test_conv2d_add_relu_conv2d():  # NHWC propagates through add/relu into the 2nd conv2d
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 4, 28, 28), "float32"),
+            w: R.Tensor((4, 4, 3, 3), "float32"),
+            bias: R.Tensor((2, 4, 26, 26), "float32"),
+        ) -> R.Tensor(None, "float32", ndim=4):
+            with R.dataflow():
+                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
+                gv2: R.Tensor((2, 4, 26, 26), "float32") = R.add(gv, bias)
+                gv3: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv2)
+                gv4: R.Tensor((2, 4, 24, 24), "float32") = R.nn.conv2d(gv3, w, out_dtype="float32")
+                R.output(gv4)
+            return gv4
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 4, 28, 28), dtype="float32"),
+            w: R.Tensor((4, 4, 3, 3), dtype="float32"),
+            bias: R.Tensor((2, 4, 26, 26), dtype="float32"),
+        ) -> R.Tensor((2, 4, 24, 24), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((2, 28, 28, 4), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
+                lv1: R.Tensor((4, 3, 3, 4), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
+                gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                    out_layout="NHWC",
+                    out_dtype="float32",
+                )
+                lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.permute_dims(
+                    bias, axes=[0, 2, 3, 1]
+                )
+                gv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.add(gv, lv2)
+                gv3: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.relu(gv2)
+                lv3: R.Tensor((4, 3, 3, 4), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
+                lv4: R.Tensor((2, 24, 24, 4), dtype="float32") = R.nn.conv2d(
+                    gv3,
+                    lv3,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                    out_layout="NHWC",
+                    out_dtype="float32",
+                )
+                gv4: R.Tensor((2, 4, 24, 24), dtype="float32") = R.permute_dims(
+                    lv4, axes=[0, 3, 1, 2]
+                )
+                R.output(gv4)
+            return gv4
+
+    verify(Input, Expected)
+
+
+def test_conv2d_fma_relu_conv2d():  # layout falls back to NCHW around ewise_fma and relu
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 4, 28, 28), "float32"),
+            w: R.Tensor((4, 4, 3, 3), "float32"),
+            scale: R.Tensor((2, 4, 26, 26), dtype="float32"),
+            bias: R.Tensor((2, 4, 26, 26), "float32"),
+        ) -> R.Tensor(None, "float32", ndim=4):
+            with R.dataflow():
+                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
+                gv2: R.Tensor((2, 4, 26, 26), "float32") = R.ewise_fma(gv, scale, bias)
+                gv3: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv2)
+                gv4: R.Tensor((2, 4, 24, 24), "float32") = R.nn.conv2d(gv3, w, out_dtype="float32")
+                R.output(gv4)
+            return gv4
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 4, 28, 28), dtype="float32"),
+            w: R.Tensor((4, 4, 3, 3), dtype="float32"),
+            scale: R.Tensor((2, 4, 26, 26), dtype="float32"),
+            bias: R.Tensor((2, 4, 26, 26), dtype="float32"),
+        ) -> R.Tensor((2, 4, 24, 24), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((2, 28, 28, 4), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
+                lv1: R.Tensor((4, 3, 3, 4), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
+                gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                    out_layout="NHWC",
+                    out_dtype="float32",
+                )
+                lv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims(
+                    gv, axes=[0, 3, 1, 2]
+                )
+                gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.ewise_fma(lv2, scale, bias)
+                gv3: R.Tensor((2, 4, 26, 26), dtype="float32") = R.nn.relu(gv2)
+                lv3: R.Tensor((2, 26, 26, 4), dtype="float32") = R.permute_dims(
+                    gv3, axes=[0, 2, 3, 1]
+                )
+                lv4: R.Tensor((4, 3, 3, 4), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
+                lv5: R.Tensor((2, 24, 24, 4), dtype="float32") = R.nn.conv2d(
+                    lv3,
+                    lv4,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                    out_layout="NHWC",
+                    out_dtype="float32",
+                )
+                gv4: R.Tensor((2, 4, 24, 24), dtype="float32") = R.permute_dims(
+                    lv5, axes=[0, 3, 1, 2]
+                )
+                R.output(gv4)
+            return gv4
+
+    verify(Input, Expected)
+
+
+def test_conv2d_sum():  # reduction axes [2, 3] (H, W) are remapped to [1, 2] in NHWC
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32")
+        ) -> R.Tensor(None, "float32", ndim=2):
+            with R.dataflow():
+                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
+                gv2: R.Tensor((2, 4), "float32") = R.sum(gv, axis=[2, 3])
+                R.output(gv2)
+            return gv2
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32")
+        ) -> R.Tensor(None, dtype="float32", ndim=2):
+            with R.dataflow():
+                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
+                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
+                gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                    out_layout="NHWC",
+                    out_dtype="float32",
+                )
+                gv2: R.Tensor((2, 4), dtype="float32") = R.sum(gv, axis=[1, 2], keepdims=False)
+                R.output(gv2)
+            return gv2
+
+    verify(Input, Expected)
+
+
+def test_conv2d_sum_keepdim():  # keepdims sum stays NHWC and is permuted back at the end
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32")
+        ) -> R.Tensor(None, "float32", ndim=4):
+            with R.dataflow():
+                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
+                gv2: R.Tensor((2, 4, 1, 1), "float32") = R.sum(gv, axis=[2, 3], keepdims=True)
+                R.output(gv2)
+            return gv2
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32")
+        ) -> R.Tensor(None, dtype="float32", ndim=4):
+            with R.dataflow():
+                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
+                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
+                gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                    out_layout="NHWC",
+                    out_dtype="float32",
+                )
+                lv2: R.Tensor((2, 1, 1, 4), dtype="float32") = R.sum(gv, axis=[1, 2], keepdims=True)
+                gv2: R.Tensor((2, 4, 1, 1), dtype="float32") = R.permute_dims(
+                    lv2, axes=[0, 3, 1, 2]
+                )
+                R.output(gv2)
+            return gv2
+
+    verify(Input, Expected)
+
+
+def test_conv2d_transpose():  # user permute_dims is composed with the NCHW->NHWC change
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32")
+        ) -> R.Tensor(None, "float32", ndim=4):
+            with R.dataflow():
+                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
+                gv2: R.Tensor((26, 26, 4, 2), "float32") = R.permute_dims(gv, axes=[3, 2, 1, 0])
+                R.output(gv2)
+            return gv2
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32")
+        ) -> R.Tensor(None, dtype="float32", ndim=4):
+            with R.dataflow():
+                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
+                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
+                gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                    out_layout="NHWC",
+                    out_dtype="float32",
+                )
+                gv2: R.Tensor((26, 26, 4, 2), dtype="float32") = R.permute_dims(
+                    gv, axes=[2, 1, 3, 0]
+                )
+                R.output(gv2)
+            return gv2
+
+    verify(Input, Expected)
+
+
+def test_conv2d_expand_dims():  # expand_dims runs on NHWC data; a 6-D permute restores order
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32")
+        ) -> R.Tensor(None, "float32", ndim=6):
+            with R.dataflow():
+                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
+                gv2: R.Tensor((2, 1, 4, 1, 26, 26), "float32") = R.expand_dims(gv, axis=(-3, 1))
+                R.output(gv2)
+            return gv2
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32")
+        ) -> R.Tensor(None, dtype="float32", ndim=6):
+            with R.dataflow():
+                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
+                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
+                gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                    out_layout="NHWC",
+                    out_dtype="float32",
+                )
+                lv2: R.Tensor((2, 1, 26, 1, 26, 4), dtype="float32") = R.expand_dims(
+                    gv, axis=[-3, 1]
+                )
+                gv2: R.Tensor((2, 1, 4, 1, 26, 26), dtype="float32") = R.permute_dims(
+                    lv2, axes=[0, 1, 5, 3, 2, 4]
+                )
+                R.output(gv2)
+            return gv2
+
+    verify(Input, Expected)
+
+
+def test_conv2d_expand_dims_squeeze():  # expand_dims + squeeze round-trip stays in NHWC
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32")
+        ) -> R.Tensor(None, "float32", ndim=4):
+            with R.dataflow():
+                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
+                gv2: R.Tensor((2, 1, 4, 1, 26, 26), "float32") = R.expand_dims(gv, axis=(-3, 1))
+                gv3: R.Tensor((2, 4, 26, 26), "float32") = R.squeeze(gv2, axis=[1, 3])
+                R.output(gv3)
+            return gv3
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32")
+        ) -> R.Tensor(None, dtype="float32", ndim=4):
+            with R.dataflow():
+                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
+                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
+                gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                    out_layout="NHWC",
+                    out_dtype="float32",
+                )
+                gv2: R.Tensor((2, 1, 26, 1, 26, 4), dtype="float32") = R.expand_dims(
+                    gv, axis=[-3, 1]
+                )
+                lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.squeeze(gv2, axis=[1, 3])
+                gv3: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims(
+                    lv2, axes=[0, 3, 1, 2]
+                )
+                R.output(gv3)
+            return gv3
+
+    verify(Input, Expected)
+
+
+def test_conv2d_strided_slice():  # slice axes [1, 2, 3] are remapped to [3, 1, 2] in NHWC
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32")
+        ) -> R.Tensor(None, "float32", ndim=4):
+            with R.dataflow():
+                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
+                gv2: R.Tensor((2, 2, 9, 7), dtype="float32") = R.strided_slice(
+                    gv, begin=[0, 0, 0], end=[4, 26, 26], strides=[2, 3, 4], axes=[1, 2, 3]
+                )
+                R.output(gv2)
+            return gv2
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32")
+        ) -> R.Tensor(None, dtype="float32", ndim=4):
+            with R.dataflow():
+                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
+                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
+                gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                    out_layout="NHWC",
+                    out_dtype="float32",
+                )
+                lv2: R.Tensor((2, 9, 7, 2), dtype="float32") = R.strided_slice(
+                    gv, axes=[3, 1, 2], begin=[0, 0, 0], end=[4, 26, 26], strides=[2, 3, 4]
+                )
+                gv2: R.Tensor((2, 2, 9, 7), dtype="float32") = R.permute_dims(
+                    lv2, axes=[0, 3, 1, 2]
+                )
+                R.output(gv2)
+            return gv2
+
+    verify(Input, Expected)
+
+
+def test_conv2d_relu_concat():  # concat axis 1 (C) becomes axis 3 in NHWC
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32")
+        ) -> R.Tensor(None, "float32", ndim=4):
+            with R.dataflow():
+                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
+                gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv)
+                gv3: R.Tensor((2, 8, 26, 26), "float32") = R.concat((gv, gv2), axis=1)
+                R.output(gv3)
+            return gv3
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32")
+        ) -> R.Tensor(None, dtype="float32", ndim=4):
+            with R.dataflow():
+                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
+                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
+                gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                    out_layout="NHWC",
+                    out_dtype="float32",
+                )
+                gv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.relu(gv)
+                lv2: R.Tensor((2, 26, 26, 8), dtype="float32") = R.concat((gv, gv2), axis=3)
+                gv3: R.Tensor((2, 8, 26, 26), dtype="float32") = R.permute_dims(
+                    lv2, axes=[0, 3, 1, 2]
+                )
+                R.output(gv3)
+            return gv3
+
+    verify(Input, Expected)
+
+
+def test_conv2d_relu_concat_split():  # each split output is permuted back to NCHW
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32")):
+            with R.dataflow():
+                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
+                gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv)
+                gv3: R.Tensor((2, 8, 26, 26), "float32") = R.concat((gv, gv2), axis=1)
+                gv4 = R.split(gv3, indices_or_sections=2, axis=1)
+                R.output(gv4)
+            return gv4
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32")
+        ):
+            with R.dataflow():
+                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
+                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
+                gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                    out_layout="NHWC",
+                    out_dtype="float32",
+                )
+                gv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.relu(gv)
+                gv3: R.Tensor((2, 26, 26, 8), dtype="float32") = R.concat((gv, gv2), axis=3)
+                gv4: R.Tuple(
+                    R.Tensor((2, 26, 26, 4), dtype="float32"),
+                    R.Tensor((2, 26, 26, 4), dtype="float32"),
+                ) = R.split(gv3, indices_or_sections=2, axis=3)
+                lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = gv4[0]
+                lv3: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims(
+                    lv2, axes=[0, 3, 1, 2]
+                )
+                lv4: R.Tensor((2, 26, 26, 4), dtype="float32") = gv4[1]
+                lv5: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims(
+                    lv4, axes=[0, 3, 1, 2]
+                )
+                gv5 = (lv3, lv5)
+                R.output(gv5)
+            return gv5
+
+    verify(Input, Expected)
+
+
+def test_conv2d_maxpool2d():  # max_pool2d switches its layout/out_layout to NHWC
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32")
+        ) -> R.Tensor(None, "float32", ndim=4):
+            with R.dataflow():
+                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
+                gv2 = R.nn.max_pool2d(
+                    gv,
+                    pool_size=[2, 2],
+                    strides=[2, 2],
+                    padding=[0, 0],
+                    layout="NCHW",
+                    out_layout="NCHW",
+                )
+                R.output(gv2)
+            return gv2
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32")
+        ) -> R.Tensor(None, dtype="float32", ndim=4):
+            with R.dataflow():
+                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
+                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
+                gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                    out_layout="NHWC",
+                    out_dtype="float32",
+                )
+                lv2: R.Tensor((2, 13, 13, 4), dtype="float32") = R.nn.max_pool2d(
+                    gv,
+                    pool_size=[2, 2],
+                    strides=[2, 2],
+                    dilation=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    ceil_mode=False,
+                    layout="NHWC",
+                    out_layout="NHWC",
+                )
+                gv2: R.Tensor((2, 4, 13, 13), dtype="float32") = R.permute_dims(
+                    lv2, axes=[0, 3, 1, 2]
+                )
+                R.output(gv2)
+            return gv2
+
+    verify(Input, Expected)
+
+
+def test_conv2d_avgpool2d():  # NOTE: exercises adaptive_avg_pool2d (not avg_pool2d) in NHWC
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32")
+        ) -> R.Tensor(None, "float32", ndim=4):
+            with R.dataflow():
+                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
+                gv2 = R.nn.adaptive_avg_pool2d(gv, output_size=[13, 13], layout="NCHW")
+                R.output(gv2)
+            return gv2
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32")
+        ) -> R.Tensor(None, dtype="float32", ndim=4):
+            with R.dataflow():
+                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
+                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
+                gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                    out_layout="NHWC",
+                    out_dtype="float32",
+                )
+                lv2: R.Tensor((2, 13, 13, 4), dtype="float32") = R.nn.adaptive_avg_pool2d(
+                    gv, output_size=[13, 13], layout="NHWC", out_layout="NHWC"
+                )
+                gv2: R.Tensor((2, 4, 13, 13), dtype="float32") = R.permute_dims(
+                    lv2, axes=[0, 3, 1, 2]
+                )
+                R.output(gv2)
+            return gv2
+
+    verify(Input, Expected)
+
+
+def test_conv2d_softmax():  # softmax axis 1 (C) becomes axis 3 in NHWC
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32")
+        ) -> R.Tensor(None, "float32", ndim=4):
+            with R.dataflow():
+                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
+                gv2 = R.nn.softmax(gv, axis=1)
+                R.output(gv2)
+            return gv2
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32")
+        ) -> R.Tensor(None, dtype="float32", ndim=4):
+            with R.dataflow():
+                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
+                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
+                gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                    out_layout="NHWC",
+                    out_dtype="float32",
+                )
+                lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.softmax(gv, axis=3)
+                gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims(
+                    lv2, axes=[0, 3, 1, 2]
+                )
+                R.output(gv2)
+            return gv2
+
+    verify(Input, Expected)
+
+
+def test_conv2d_batchnorm():  # batch_norm axis becomes 3; only the data output is permuted back
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), "float32"),
+            w: R.Tensor((4, 3, 3, 3), "float32"),
+            gamma: R.Tensor((4,), dtype="float32"),
+            beta: R.Tensor((4,), dtype="float32"),
+            moving_mean: R.Tensor((4,), dtype="float32"),
+            moving_var: R.Tensor((4,), dtype="float32"),
+        ):
+            with R.dataflow():
+                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
+                gv2: R.Tuple(
+                    R.Tensor((2, 4, 26, 26), dtype="float32"),
+                    R.Tensor((4,), dtype="float32"),
+                    R.Tensor((4,), dtype="float32"),
+                ) = R.nn.batch_norm(gv, gamma, beta, moving_mean, moving_var, axis=1)
+                R.output(gv2)
+            return gv2
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"),
+            w: R.Tensor((4, 3, 3, 3), dtype="float32"),
+            gamma: R.Tensor((4,), dtype="float32"),
+            beta: R.Tensor((4,), dtype="float32"),
+            moving_mean: R.Tensor((4,), dtype="float32"),
+            moving_var: R.Tensor((4,), dtype="float32"),
+        ):
+            with R.dataflow():
+                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
+                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
+                gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                    out_layout="NHWC",
+                    out_dtype="float32",
+                )
+                gv2: R.Tuple(
+                    R.Tensor((2, 26, 26, 4), dtype="float32"),
+                    R.Tensor((4,), dtype="float32"),
+                    R.Tensor((4,), dtype="float32"),
+                ) = R.nn.batch_norm(
+                    gv,
+                    gamma,
+                    beta,
+                    moving_mean,
+                    moving_var,
+                    axis=3,
+                    epsilon=1.0000000000000001e-05,
+                    center=True,
+                    scale=True,
+                )
+                lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = gv2[0]
+                lv3: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims(
+                    lv2, axes=[0, 3, 1, 2]
+                )
+                lv4: R.Tensor((4,), dtype="float32") = gv2[1]
+                lv5: R.Tensor((4,), dtype="float32") = gv2[2]
+                gv3 = (lv3, lv4, lv5)
+                R.output(gv3)
+            return gv3
+
+    verify(Input, Expected)
+
+
+def test_conv2d_layernorm():  # layer_norm axes [-2, -1] (H, W) become [1, 2] in NHWC
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), "float32"),
+            w: R.Tensor((4, 3, 3, 3), "float32"),
+            gamma: R.Tensor((26, 26), dtype="float32"),
+            beta: R.Tensor((26, 26), dtype="float32"),
+        ) -> R.Tensor(None, "float32", ndim=4):
+            with R.dataflow():
+                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
+                gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.nn.layer_norm(
+                    gv, gamma, beta, axes=[-2, -1]
+                )
+                R.output(gv2)
+            return gv2
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"),
+            w: R.Tensor((4, 3, 3, 3), dtype="float32"),
+            gamma: R.Tensor((26, 26), dtype="float32"),
+            beta: R.Tensor((26, 26), dtype="float32"),
+        ) -> R.Tensor(None, dtype="float32", ndim=4):
+            with R.dataflow():
+                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
+                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
+                gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                    out_layout="NHWC",
+                    out_dtype="float32",
+                )
+                lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.layer_norm(
+                    gv,
+                    gamma,
+                    beta,
+                    axes=[1, 2],
+                    epsilon=1.0000000000000001e-05,
+                    center=True,
+                    scale=True,
+                )
+                gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims(
+                    lv2, axes=[0, 3, 1, 2]
+                )
+                R.output(gv2)
+            return gv2
+
+    verify(Input, Expected)
+
+
+def test_conv2d_resize2d():
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), 
"float32")
+        ) -> R.Tensor(None, "float32", ndim=4):
+            with R.dataflow():
+                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, 
out_dtype="float32")
+                gv2 = R.image.resize2d(gv, (52, 52), layout="NCHW")
+                R.output(gv2)
+            return gv2
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 
3, 3), dtype="float32")
+        ) -> R.Tensor(None, dtype="float32", ndim=4):
+            with R.dataflow():
+                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = 
R.permute_dims(x, axes=[0, 2, 3, 1])
+                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = 
R.permute_dims(w, axes=[0, 2, 3, 1])
+                gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                    out_layout="NHWC",
+                    out_dtype="float32",
+                )
+                lv2: R.Tensor((2, 52, 52, 4), dtype="float32") = 
R.image.resize2d(
+                    gv,
+                    (52, 52),
+                    roi=[T.float32(0), T.float32(0), T.float32(0), 
T.float32(0)],
+                    layout="NHWC",
+                    method="linear",
+                    coordinate_transformation_mode="half_pixel",
+                    rounding_method="round",
+                    cubic_alpha=-0.5,
+                    cubic_exclude=0,
+                    extrapolation_value=0,
+                    out_dtype="void",
+                )
+                gv2: R.Tensor((2, 4, 52, 52), dtype="float32") = 
R.permute_dims(
+                    lv2, axes=[0, 3, 1, 2]
+                )
+                R.output(gv2)
+            return gv2
+
+    verify(Input, Expected)
+
+
+def test_conv2d_unknown_bias_dim():
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), "float32"),
+            w: R.Tensor((4, 3, 3, 3), "float32"),
+            w2: R.Tensor(dtype="float32"),
+        ) -> R.Tensor(None, "float32"):
+            with R.dataflow():
+                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, 
out_dtype="float32")
+                gv2 = w2 + gv
+                R.output(gv2)
+            return gv2
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"),
+            w: R.Tensor((4, 3, 3, 3), dtype="float32"),
+            w2: R.Tensor(dtype="float32"),
+        ) -> R.Tensor(dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = 
R.permute_dims(x, axes=[0, 2, 3, 1])
+                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = 
R.permute_dims(w, axes=[0, 2, 3, 1])
+                gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                    out_layout="NHWC",
+                    out_dtype="float32",
+                )
+                lv2: R.Tensor((2, 4, 26, 26), dtype="float32") = 
R.permute_dims(
+                    gv, axes=[0, 3, 1, 2]
+                )
+                gv2: R.Tensor(dtype="float32") = R.add(w2, lv2)
+                R.output(gv2)
+            return gv2
+
+    verify(Input, Expected)
+
+
+def test_binary_broadcast():
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), "float32"),
+            w: R.Tensor((4, 3, 3, 3), "float32"),
+            bias: R.Tensor((26, 26), "float32"),
+        ) -> R.Tensor(None, "float32", ndim=4):
+            with R.dataflow():
+                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, 
out_dtype="float32")
+                gv2: R.Tensor((2, 4, 26, 26), "float32") = R.add(gv, bias)
+                R.output(gv2)
+            return gv2
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"),
+            w: R.Tensor((4, 3, 3, 3), dtype="float32"),
+            bias: R.Tensor((26, 26), dtype="float32"),
+        ) -> R.Tensor((2, 4, 26, 26), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = 
R.permute_dims(x, axes=[0, 2, 3, 1])
+                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = 
R.permute_dims(w, axes=[0, 2, 3, 1])
+                gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                    out_layout="NHWC",
+                    out_dtype="float32",
+                )
+                lv2: R.Tensor((2, 4, 26, 26), dtype="float32") = 
R.permute_dims(
+                    gv, axes=[0, 3, 1, 2]
+                )
+                gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.add(lv2, 
bias)
+                R.output(gv2)
+            return gv2
+
+    verify(Input, Expected)
+
+
+if __name__ == "__main__":
+    # Allow running this test module directly as a script via TVM's test entry point.
+    tvm.testing.main()

Reply via email to