This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b3b6024027 [Relax] Add FRelaxInferLayout for gather_elements operator 
(#18642)
b3b6024027 is described below

commit b3b6024027c9b83471880dfb7af892c618274131
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Wed Jan 7 19:33:01 2026 +0800

    [Relax] Add FRelaxInferLayout for gather_elements operator (#18642)
    
    ## Why
    
    The gather_elements operator lacked layout inference support, preventing
    it from participating in layout transformations during the ConvertLayout
    pass.
    
    ## How
    
    - Add InferLayoutGatherElements function that transforms the axis
    attribute according to the input layout
    - Register FRelaxInferLayout attribute
    
    ---------
    
    Co-authored-by: gemini-code-assist[bot] 
<176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
 src/relax/op/tensor/manipulate.cc                  | 34 ++++++++++++++
 .../python/relax/test_transform_convert_layout.py  | 53 ++++++++++++++++++++++
 2 files changed, 87 insertions(+)

diff --git a/src/relax/op/tensor/manipulate.cc 
b/src/relax/op/tensor/manipulate.cc
index 3170b28eeb..afb749a297 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -2150,12 +2150,46 @@ StructInfo InferStructInfoGatherElements(const Call& 
call, const BlockBuilder& c
   return TensorStructInfo(data_sinfo->dtype, indices_sinfo->ndim, 
data_sinfo->vdevice);
 }
 
+// Layout inference (FRelaxInferLayout) for relax.gather_elements.
+// Chooses a layout decision from the inputs and rewrites the `axis`
+// attribute so it refers to the same dimension under the chosen layout,
+// allowing the op to participate in the ConvertLayout pass.
+InferLayoutOutput InferLayoutGatherElements(
+    const Call& call, const ffi::Map<ffi::String, ffi::Array<ffi::String>>& 
desired_layouts,
+    const VarLayoutMap& var_layout_map) {
+  // This op does not take a user-specified desired layout; it only
+  // propagates decisions already made for its producers.
+  ICHECK(NoDesiredLayout(call, desired_layouts));
+  const auto* attrs = call->attrs.as<GatherElementsAttrs>();
+  ICHECK(attrs) << "Invalid Call";
+
+  // Layout decisions previously inferred for the data and indices inputs.
+  LayoutDecision data_layout = GetLayoutDecision(var_layout_map, 
call->args[0]);
+  LayoutDecision indices_layout = GetLayoutDecision(var_layout_map, 
call->args[1]);
+
+  LayoutDecision layout = data_layout;
+  // If data_layout is initial and indices_layout is not, prefer 
indices_layout.
+  bool data_is_initial =
+      data_layout->layout.name() == 
InitialLayout(data_layout->layout.ndim()).name();
+  bool indices_is_initial =
+      indices_layout->layout.name() == 
InitialLayout(indices_layout->layout.ndim()).name();
+  if (data_is_initial && !indices_is_initial) {
+    layout = indices_layout;
+  }
+
+  // A layout containing sub-indexed (non-primal) axes cannot be handled
+  // here; fall back to the initial layout of the data tensor, which
+  // requires a statically known rank.
+  if (layout->layout.ndim() != layout->layout.ndim_primal()) {
+    const auto* tensor_sinfo = 
GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+    ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
+    ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for 
now";
+    int ndim = tensor_sinfo->ndim;
+    layout = LayoutDecision(InitialLayout(ndim));
+  }
+
+  // Copy the attrs and remap `axis` to its position in the chosen layout.
+  ObjectPtr<GatherElementsAttrs> new_attrs = 
ffi::make_object<GatherElementsAttrs>(*attrs);
+  new_attrs->axis = FindAxis(layout->layout, attrs->axis->value);
+  // Both inputs and the single output all follow the chosen layout.
+  return InferLayoutOutput({layout, layout}, {layout}, Attrs(new_attrs));
+}
+
 // Operator registration for relax.gather_elements; the FRelaxInferLayout
 // attribute (added here) enables layout propagation in ConvertLayout.
 TVM_REGISTER_OP("relax.gather_elements")
     .set_attrs_type<GatherElementsAttrs>()
     .set_num_inputs(2)
     .add_argument("data", "Tensor", "The input tensor.")
     .add_argument("indices", "Tensor", "The indices tensor.")
     .set_attr<FInferStructInfo>("FInferStructInfo", 
InferStructInfoGatherElements)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", 
InferLayoutGatherElements)
     .set_attr<Bool>("FPurity", Bool(true));
 
 /* relax.gather_nd */
diff --git a/tests/python/relax/test_transform_convert_layout.py 
b/tests/python/relax/test_transform_convert_layout.py
index 221d680ebc..84fa9e70c7 100644
--- a/tests/python/relax/test_transform_convert_layout.py
+++ b/tests/python/relax/test_transform_convert_layout.py
@@ -5434,5 +5434,58 @@ def test_conv2d_scatter_nd():
     verify(Input, Expected)
 
 
+# ConvertLayout should propagate NHWC through gather_elements: both inputs
+# are permuted into the new layout and the gather axis is remapped from
+# axis=1 (channel position in NCHW) to axis=3 (channel position in NHWC).
+def test_conv2d_gather_elements():
+    # Input module: NCHW conv2d followed by gather_elements on the channel axis.
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), "float32"),
+            w: R.Tensor((4, 3, 3, 3), "float32"),
+            indices: R.Tensor((2, 4, 26, 26), "int64"),
+        ) -> R.Tensor(None, "float32", ndim=4):
+            with R.dataflow():
+                data: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, 
out_dtype="float32")
+                gv = R.gather_elements(data, indices, axis=1)
+                R.output(gv)
+            return gv
+
+    # Expected module after ConvertLayout: conv2d rewritten to NHWC, the
+    # indices permuted to match, gather_elements axis remapped 1 -> 3, and
+    # the final result permuted back to the original NCHW layout.
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"),
+            w: R.Tensor((4, 3, 3, 3), dtype="float32"),
+            indices: R.Tensor((2, 4, 26, 26), dtype="int64"),
+        ) -> R.Tensor(None, dtype="float32", ndim=4):
+            with R.dataflow():
+                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = 
R.permute_dims(x, axes=[0, 2, 3, 1])
+                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = 
R.permute_dims(w, axes=[0, 2, 3, 1])
+                data: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                    out_layout="NHWC",
+                    out_dtype="float32",
+                )
+                lv2: R.Tensor((2, 26, 26, 4), dtype="int64") = R.permute_dims(
+                    indices, axes=[0, 2, 3, 1]
+                )
+                lv3: R.Tensor((2, 26, 26, 4), dtype="float32") = 
R.gather_elements(
+                    data, lv2, axis=3
+                )
+                gv: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims(
+                    lv3, axes=[0, 3, 1, 2]
+                )
+                R.output(gv)
+            return gv
+
+    verify(Input, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to