This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b975db9b28 [Relax] Add FRelaxInferLayout for scatter_nd operator (#18643)
b975db9b28 is described below

commit b975db9b28959503e471da9c78b41df9a16d738e
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Wed Jan 7 17:53:27 2026 +0800

    [Relax] Add FRelaxInferLayout for scatter_nd operator (#18643)
    
    ## Why
    
    The scatter_nd operator was missing the FRelaxInferLayout attribute,
    which is needed for proper layout transformation during model optimization.
    
    ## How
    
    - Added an InferLayoutScatterND function that reuses the data tensor's
    layout for the output, since scatter_nd preserves the input shape
    - Registered the FRelaxInferLayout attribute on relax.scatter_nd
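
    A minimal sketch of how the new hook gets exercised (an assumption based
    on the test added below; `verify` there is presumed to apply the existing
    ConvertLayout and Normalize passes with the usual NHWC/OHWI mapping):

    ```python
    import tvm
    from tvm import relax

    # `Input` is the IRModule from the new test case below.
    mod = relax.transform.ConvertLayout({"relax.nn.conv2d": ["NHWC", "OHWI"]})(Input)
    mod = relax.transform.Normalize()(mod)
    # With FRelaxInferLayout registered, scatter_nd keeps the converted NHWC
    # layout, so permute_dims appears only at the graph boundary.
    mod.show()
    ```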
---
 src/relax/op/tensor/manipulate.cc                  | 40 +++++++++++++++++
 .../python/relax/test_transform_convert_layout.py  | 52 ++++++++++++++++++++++
 2 files changed, 92 insertions(+)

diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index 7c5682d462..3170b28eeb 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -2780,6 +2780,45 @@ StructInfo InferStructInfoScatterND(const Call& call, const BlockBuilder& ctx) {
   return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice);
 }
 
+InferLayoutOutput InferLayoutScatterND(
+    const Call& call, const ffi::Map<ffi::String, ffi::Array<ffi::String>>& desired_layouts,
+    const VarLayoutMap& var_layout_map) {
+  ICHECK(NoDesiredLayout(call, desired_layouts));
+
+  LayoutDecision data_layout = GetLayoutDecision(var_layout_map, call->args[0]);
+  LayoutDecision indices_layout = GetLayoutDecision(var_layout_map, call->args[1]);
+  LayoutDecision updates_layout = GetLayoutDecision(var_layout_map, call->args[2]);
+
+  const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  const auto* updates_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[2]);
+  ICHECK(data_sinfo != nullptr) << "Invalid Call";
+  ICHECK(updates_sinfo != nullptr) << "Invalid Call";
+  ICHECK(!data_sinfo->IsUnknownNdim()) << "Only support static ndim for now";
+  ICHECK(!updates_sinfo->IsUnknownNdim()) << "Only support static ndim for now";
+
+  LayoutDecision layout = data_layout;
+  LayoutDecision out_updates_layout = updates_layout;
+
+  // Check if data has a sub-indexed layout
+  bool has_sub_indexed_layout = layout->layout.ndim() != layout->layout.ndim_primal();
+
+  if (has_sub_indexed_layout) {
+    // Fall back to initial layouts for both data and updates
+    layout = LayoutDecision(InitialLayout(data_sinfo->ndim));
+    out_updates_layout = LayoutDecision(InitialLayout(updates_sinfo->ndim));
+  } else if (data_sinfo->ndim == updates_sinfo->ndim) {
+    // When data and updates have the same rank, apply the same layout to both
+    out_updates_layout = layout;
+  } else {
+    // Different ranks - fall back to initial layouts for both
+    layout = LayoutDecision(InitialLayout(data_sinfo->ndim));
+    out_updates_layout = LayoutDecision(InitialLayout(updates_sinfo->ndim));
+  }
+
+  return InferLayoutOutput({layout, indices_layout, out_updates_layout}, {layout},
+                           Attrs(call->attrs));
+}
+
 TVM_REGISTER_OP("relax.scatter_nd")
     .set_attrs_type<ScatterNDAttrs>()
     .set_num_inputs(3)
@@ -2787,6 +2826,7 @@ TVM_REGISTER_OP("relax.scatter_nd")
     .add_argument("indices", "Tensor", "The indices tensor.")
     .add_argument("updates", "Tensor", "The input tensor of updates.")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoScatterND)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutScatterND)
     .set_attr<Bool>("FPurity", Bool(true));
 
 /* relax.scatter_nd */
diff --git a/tests/python/relax/test_transform_convert_layout.py b/tests/python/relax/test_transform_convert_layout.py
index 26990bc44d..221d680ebc 100644
--- a/tests/python/relax/test_transform_convert_layout.py
+++ b/tests/python/relax/test_transform_convert_layout.py
@@ -5382,5 +5382,57 @@ def test_conv2d_scatter_elements():
     verify(Input, Expected)
 
 
+def test_conv2d_scatter_nd():
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), "float32"),
+            w: R.Tensor((4, 3, 3, 3), "float32"),
+            indices: R.Tensor((2, 1), "int64"),
+        ) -> R.Tensor(None, "float32", ndim=4):
+            with R.dataflow():
+                data: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
+                updates: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(data)
+                gv = R.scatter_nd(data, indices, updates)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"),
+            w: R.Tensor((4, 3, 3, 3), dtype="float32"),
+            indices: R.Tensor((2, 1), dtype="int64"),
+        ) -> R.Tensor(None, dtype="float32", ndim=4):
+            with R.dataflow():
+                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
+                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
+                data: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                    out_layout="NHWC",
+                    out_dtype="float32",
+                )
+                updates: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.relu(data)
+                lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.scatter_nd(
+                    data, indices, updates, reduction="update"
+                )
+                gv: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims(
+                    lv2, axes=[0, 3, 1, 2]
+                )
+                R.output(gv)
+            return gv
+
+    verify(Input, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
