This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 44d973b0aa [Relax] Add layout inference support for repeat operator (#18579)
44d973b0aa is described below

commit 44d973b0aa939307a36aef1011de30833837c664
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Thu Dec 18 11:56:20 2025 +0800

    [Relax] Add layout inference support for repeat operator (#18579)
    
    ## How
    
    - Implemented InferLayoutRepeat function that:
      - Preserves layout when axis is specified (with axis transformation)
      - Returns 1D layout when axis is not specified (flatten mode)
    - Transforms the axis parameter based on layout changes (e.g., NCHW
      axis=1 → NHWC axis=3); see the sketch below
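
    For illustration only, a minimal standalone sketch of the axis remapping
    described above; the helper name `remap_axis` and the plain string-index
    lookup are assumptions for this sketch, not the actual implementation
    (the patch itself works on Layout objects via TransposeStrLike):

    ```python
    # Hypothetical sketch: identify the repeated axis by the dimension letter
    # it names in the source layout, then find that letter in the new layout.
    def remap_axis(axis: int, src_layout: str, dst_layout: str) -> int:
        axis = axis % len(src_layout)   # normalize a negative axis
        dim = src_layout[axis]          # dimension letter, e.g. 'C'
        return dst_layout.index(dim)    # its position in the new layout

    assert remap_axis(1, "NCHW", "NHWC") == 3   # NCHW axis=1 -> NHWC axis=3
    assert remap_axis(-1, "NCHW", "NHWC") == 2  # W: index 3 in NCHW, 2 in NHWC
    ```
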
---
 src/relax/op/tensor/manipulate.cc                  | 60 ++++++++++++++-
 .../python/relax/test_transform_convert_layout.py  | 85 ++++++++++++++++++++++
 2 files changed, 144 insertions(+), 1 deletion(-)

diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index 0310c7f46b..493198fbd0 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -1805,12 +1805,70 @@ StructInfo InferStructInfoRepeat(const Call& call, const BlockBuilder& ctx) {
   return TensorStructInfo(ShapeExpr(shape_array), data_sinfo->dtype, data_sinfo->vdevice);
 }
 
-// TODO(relax-team): implement FRelaxInferLayout for repeat
+InferLayoutOutput InferLayoutRepeat(
+    const Call& call, const ffi::Map<ffi::String, ffi::Array<ffi::String>>& desired_layouts,
+    const VarLayoutMap& var_layout_map) {
+  ICHECK(NoDesiredLayout(call, desired_layouts));
+
+  const auto* attrs = call->attrs.as<RepeatAttrs>();
+  ICHECK(attrs != nullptr) << "Invalid Call";
+  const auto* tensor_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
+  ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now";
+
+  LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, call->args[0]);
+  int ndim = tensor_sinfo->ndim;
+
+  // Can't handle sub indexed layouts.
+  if (existing_layout->layout.ndim() != existing_layout->layout.ndim_primal()) {
+    existing_layout = LayoutDecision(InitialLayout(ndim));
+  }
+
+  // When axis is not specified, the output is 1D (flattened)
+  if (!attrs->axis.has_value()) {
+    return InferLayoutOutput({existing_layout}, {InitialLayoutDecision(1)}, Attrs(call->attrs));
+  }
+
+  // Transform the axis based on the layout
+  int axis = attrs->axis.value();
+  if (axis < 0) {
+    axis += ndim;
+  }
+
+  // Create a mapping from original layout to existing layout
+  std::string axis_str(ndim, '0');
+  axis_str[axis] = '1';
+  for (int i = 0, j = 0; i < ndim; ++i) {
+    if (axis_str[i] != '1') {
+      axis_str[i] = 'A' + j++;
+    }
+  }
+
+  ffi::String new_axis_str =
+      TransposeStrLike(axis_str, InitialLayout(ndim), existing_layout->layout);
+
+  int64_t new_axis = -1;
+  for (size_t i = 0; i < new_axis_str.size(); ++i) {
+    if (new_axis_str.at(i) == '1') {
+      new_axis = i;
+      break;
+    }
+  }
+  ICHECK_GE(new_axis, 0) << "Failed to find transformed axis";
+
+  ObjectPtr<RepeatAttrs> new_attrs = ffi::make_object<RepeatAttrs>(*attrs);
+  new_attrs->axis = new_axis;
+
+  // When axis is specified, the layout is preserved
+  return InferLayoutOutput({existing_layout}, {existing_layout}, Attrs(new_attrs));
+}
+
 TVM_REGISTER_OP("relax.repeat")
     .set_attrs_type<RepeatAttrs>()
     .set_num_inputs(1)
     .add_argument("data", "Tensor", "The input tensor.")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoRepeat)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutRepeat)
     .set_attr<Bool>("FPurity", Bool(true));
 
 /* relax.tile */
diff --git a/tests/python/relax/test_transform_convert_layout.py b/tests/python/relax/test_transform_convert_layout.py
index 83b81a6898..95f043ef66 100644
--- a/tests/python/relax/test_transform_convert_layout.py
+++ b/tests/python/relax/test_transform_convert_layout.py
@@ -4992,5 +4992,90 @@ def test_pooling_branching_texture_params():
     verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]})
 
 
+def test_conv2d_repeat():
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), 
"float32")
+        ) -> R.Tensor(None, "float32", ndim=4):
+            with R.dataflow():
+                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
+                gv2: R.Tensor((2, 8, 26, 26), "float32") = R.repeat(gv, repeats=2, axis=1)
+                R.output(gv2)
+            return gv2
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32")
+        ) -> R.Tensor(None, dtype="float32", ndim=4):
+            with R.dataflow():
+                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
+                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
+                gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                    out_layout="NHWC",
+                    out_dtype="float32",
+                )
+                lv2: R.Tensor((2, 26, 26, 8), dtype="float32") = R.repeat(gv, repeats=2, axis=3)
+                gv2: R.Tensor((2, 8, 26, 26), dtype="float32") = R.permute_dims(
+                    lv2, axes=[0, 3, 1, 2]
+                )
+                R.output(gv2)
+            return gv2
+
+    verify(Input, Expected)
+
+
+def test_conv2d_repeat_flatten():
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), 
"float32")
+        ) -> R.Tensor((5408,), "float32"):
+            with R.dataflow():
+                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
+                gv2: R.Tensor((5408,), "float32") = R.repeat(gv, repeats=1)
+                R.output(gv2)
+            return gv2
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32")
+        ) -> R.Tensor((5408,), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
+                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
+                gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                    out_layout="NHWC",
+                    out_dtype="float32",
+                )
+                gv2: R.Tensor((5408,), dtype="float32") = R.repeat(gv, repeats=1)
+                R.output(gv2)
+            return gv2
+
+    verify(Input, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
