This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a95b90c375 [Relax] Add FRelaxInferLayout for flip operator (#18637)
a95b90c375 is described below
commit a95b90c375ccee09eef48c668ab01824187da5f5
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Tue Jan 6 14:26:05 2026 +0800
[Relax] Add FRelaxInferLayout for flip operator (#18637)
## Why
The flip operator lacked layout inference support, preventing it from
participating in layout transformations during the ConvertLayout pass.
## How
- Add InferLayoutFlip function that transforms the axis attribute
according to the input layout
- Register FRelaxInferLayout attribute for relax.flip operator
- Add test case for conv2d followed by flip with layout conversion
---
src/relax/op/tensor/manipulate.cc | 33 ++++++++++++++++
.../python/relax/test_transform_convert_layout.py | 44 ++++++++++++++++++++++
2 files changed, 77 insertions(+)
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index 4ac7affb0c..22636afb97 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -2047,11 +2047,44 @@ StructInfo InferStructInfoFlip(const Call& call, const
BlockBuilder& ctx) {
return data_sinfo;
}
+InferLayoutOutput InferLayoutFlip(  // Layout-inference hook for relax.flip: rewrites the `axis` attribute to match the layout chosen for the input.
+ const Call& call, const ffi::Map<ffi::String, ffi::Array<ffi::String>>&
desired_layouts,
+ const VarLayoutMap& var_layout_map) {
+ ICHECK(NoDesiredLayout(call, desired_layouts));  // NOTE(review): presumably asserts that no desired layout was requested for this op -- confirm against NoDesiredLayout.
+
+ const auto* attrs = call->attrs.as<FlipAttrs>();
+ ICHECK(attrs != nullptr) << "Invalid Call";  // the call must carry FlipAttrs
+ const auto* tensor_sinfo =
GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+ ICHECK(tensor_sinfo != nullptr) << "Invalid Call";  // the single argument must have tensor struct info
+ ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now";
+
+ LayoutDecision existing_layout = GetLayoutDecision(var_layout_map,
call->args[0]);  // layout decision recorded for the input argument in var_layout_map
+ int ndim = tensor_sinfo->ndim;
+
+ if (existing_layout->layout.ndim() != existing_layout->layout.ndim_primal())
{  // NOTE(review): ndim() != ndim_primal() presumably means the layout has sub-axes -- confirm; in that case use the initial layout for `ndim`.
+ existing_layout = LayoutDecision(InitialLayout(ndim));
+ }
+
+ int axis = attrs->axis.IntValue();
+ if (axis < 0) {  // normalize a negative axis by adding ndim
+ axis += ndim;
+ }
+
+ const int new_axis = FindAxis(existing_layout->layout, axis);  // axis index to use under existing_layout (the accompanying test expects axis 1 -> 3 for an NHWC conversion)
+ ICHECK_GE(new_axis, 0) << "Failed to find transformed axis";
+
+ ObjectPtr<FlipAttrs> new_attrs = ffi::make_object<FlipAttrs>(*attrs);  // copy the attrs so the original call's attrs are not modified
+ new_attrs->axis = Integer(new_axis);
+
+ return InferLayoutOutput({existing_layout}, {existing_layout},
Attrs(new_attrs));  // the same layout decision is passed in both layout lists, together with the rewritten attrs
+}
+
TVM_REGISTER_OP("relax.flip")
.set_attrs_type<FlipAttrs>()
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoFlip)
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutFlip)  // registers layout inference so relax.flip can take part in the ConvertLayout pass
.set_attr<Bool>("FPurity", Bool(true));
/* relax.gather_elements */
diff --git a/tests/python/relax/test_transform_convert_layout.py b/tests/python/relax/test_transform_convert_layout.py
index 5ba0c4d867..8ae96e9c07 100644
--- a/tests/python/relax/test_transform_convert_layout.py
+++ b/tests/python/relax/test_transform_convert_layout.py
@@ -5283,5 +5283,49 @@ def test_conv2d_dynamic_strided_slice():
verify(Input, Expected)
+def test_conv2d_flip():  # layout-conversion test: conv2d followed by flip
+ @I.ir_module
+ class Input:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3),
"float32")
+ ) -> R.Tensor(None, "float32", ndim=4):
+ with R.dataflow():
+ gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w,
out_dtype="float32")
+ gv2: R.Tensor((2, 4, 26, 26), "float32") = R.flip(gv, axis=1)  # flip along axis 1 of the (2, 4, 26, 26) conv2d output
+ R.output(gv2)
+ return gv2
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3,
3, 3), dtype="float32")
+ ) -> R.Tensor(None, dtype="float32", ndim=4):
+ with R.dataflow():
+ lv: R.Tensor((2, 28, 28, 3), dtype="float32") =
R.permute_dims(x, axes=[0, 2, 3, 1])
+ lv1: R.Tensor((4, 3, 3, 3), dtype="float32") =
R.permute_dims(w, axes=[0, 2, 3, 1])
+ gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+ lv,
+ lv1,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NHWC",
+ kernel_layout="OHWI",
+ out_layout="NHWC",
+ out_dtype="float32",
+ )
+ lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.flip(gv,
axis=3)  # the flip now acts on the (2, 26, 26, 4) NHWC tensor, so the expected axis is 3 instead of 1
+ gv2: R.Tensor((2, 4, 26, 26), dtype="float32") =
R.permute_dims(
+ lv2, axes=[0, 3, 1, 2]  # permute the result back to the original (2, 4, 26, 26) shape
+ )
+ R.output(gv2)
+ return gv2
+
+ verify(Input, Expected)  # NOTE(review): verify is defined outside this hunk; presumably runs ConvertLayout on Input and compares with Expected -- confirm
+
+
if __name__ == "__main__":
tvm.testing.main()