This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new da52d7d0f6 fix: Support 5D volumetric inputs in ONNX GridSample frontend converter (#19816)
da52d7d0f6 is described below

commit da52d7d0f62e2831212763189870caf447c85bb0
Author: Matt Van Horn <[email protected]>
AuthorDate: Thu Jun 18 12:24:23 2026 -0700

    fix: Support 5D volumetric inputs in ONNX GridSample frontend converter (#19816)
    
    ## Summary
    The Relax ONNX frontend's GridSample._impl_v16 converter unconditionally
    permutes the grid from ONNX [N,H,W,2] to TVM [N,2,H,W] and calls
    image.grid_sample with layout="NCHW". For 5D volumetric inputs
    ([N,C,D,H,W] with grid [N,D,H,W,3]) this crashes at permute_dims with an
    InternalError ('PermuteDims expects the number of input axes to equal
    the ndim of the input tensor').
    
    ## Changes
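    In GridSample._impl_v16, read data.struct_info.ndim and dispatch on
    rank. For ndim==4, keep the existing permute_dims(grid,[0,3,1,2]) +
    grid_sample(layout="NCHW"). For ndim==5, permute the grid with
    permute_dims(grid,[0,4,1,2,3]) and call grid_sample(layout="NCDHW").
    Raise NotImplementedError for 5D inputs with mode='cubic' and for any
    input rank other than 4 or 5.
    
    In src/relax/op/image/resize.cc, InferStructInfoGridSample now checks
    the data layout against "NCDHW" when attrs->layout is "NCDHW" (and
    against "NCHW" otherwise), and reads the output spatial extents from the
    permuted TVM-layout grid shape (grid->values[2:]).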
    
    Fixes #19688
    
    ---------
    
    Co-authored-by: Matt Van Horn <[email protected]>
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py |  28 ++++-
 src/relax/op/image/resize.cc                    |  35 +++++--
 tests/python/relax/test_frontend_onnx.py        | 134 +++++++++++++++++++++++-
 3 files changed, 181 insertions(+), 16 deletions(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py 
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index a8cb216e26..3cfe7c892c 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -5084,15 +5084,35 @@ class GridSample(OnnxOpConverter):
 
         align_corners = bool(attr.get("align_corners", 0))
 
-        # ONNX grid shape: [N, H_out, W_out, 2]
-        # TVM grid shape:  [N, 2, H_out, W_out]
-        grid = relax.op.permute_dims(grid, [0, 3, 1, 2])
+        if hasattr(data.struct_info, "ndim"):
+            ndim = data.struct_info.ndim
+        else:
+            ndim = len(data.struct_info.shape)
+
+        if ndim == 5 and method == "bicubic":
+            raise NotImplementedError(
+                "5D (volumetric) GridSample with mode='cubic' is not supported 
"
+                "(TOPI 3D grid_sample supports only bilinear and nearest)."
+            )
+
+        if ndim == 4:
+            # ONNX grid shape: [N, H_out, W_out, 2]
+            # TVM grid shape:  [N, 2, H_out, W_out]
+            grid = relax.op.permute_dims(grid, [0, 3, 1, 2])
+            layout = "NCHW"
+        elif ndim == 5:
+            # ONNX grid shape: [N, D_out, H_out, W_out, 3]
+            # TVM grid shape:  [N, 3, D_out, H_out, W_out]
+            grid = relax.op.permute_dims(grid, [0, 4, 1, 2, 3])
+            layout = "NCDHW"
+        else:
+            raise NotImplementedError(f"GridSample only supports 4D or 5D 
input, got {ndim}D.")
 
         return relax.op.image.grid_sample(
             data,
             grid,
             method=method,
-            layout="NCHW",
+            layout=layout,
             padding_mode=padding_mode,
             align_corners=align_corners,
         )
diff --git a/src/relax/op/image/resize.cc b/src/relax/op/image/resize.cc
index 1b84f3dfc8..653ea04c63 100644
--- a/src/relax/op/image/resize.cc
+++ b/src/relax/op/image/resize.cc
@@ -305,14 +305,19 @@ StructInfo InferStructInfoGridSample(const Call& call, 
const BlockBuilder& ctx)
   }
 
   const auto* attrs = call->attrs.as<GridSampleAttrs>();
-  auto [data_layout, data2NCHW] = CheckTensorLayout(call, ctx, attrs->layout,
-                                                    /*tgt_layout=*/"NCHW",
-                                                    /*tensor_name=*/"data");
+
+  // grid_sample supports both 2D (NCHW) and 3D (NCDHW) sampling. The frontend
+  // sets attrs->layout to "NCDHW" for the volumetric case; everything else is
+  // treated as the 2D NCHW path so existing behavior is preserved.
+  const bool is_ncdhw = (attrs->layout == "NCDHW");
+
+  auto [data_layout, data2tgt] =
+      CheckTensorLayout(call, ctx, attrs->layout,
+                        /*tgt_layout=*/is_ncdhw ? "NCDHW" : "NCHW",
+                        /*tensor_name=*/"data");
 
   DataType out_dtype = data_sinfo->dtype;
 
-  // Output shape: [N, C, grid_H, grid_W]
-  // grid shape for NCHW layout input is [N, H_out, W_out, 2]
   ffi::Optional<ShapeExpr> data_shape = CheckNdimPerLayoutAndGetShape(
       call, ctx, ffi::GetRef<TensorStructInfo>(data_sinfo), data_layout);
   const auto* grid_shape = grid_sinfo->shape.as<ShapeExprNode>();
@@ -321,13 +326,21 @@ StructInfo InferStructInfoGridSample(const Call& call, 
const BlockBuilder& ctx)
     return TensorStructInfo(out_dtype, data_layout.ndim(), 
data_sinfo->vdevice);
   }
 
-  ffi::Array<PrimExpr> data_NCHW_shape = 
data2NCHW.ForwardShape(data_shape.value()->values);
-  // grid is [N, H_out, W_out, 2], output is [N, C, H_out, W_out]
-  ffi::Array<PrimExpr> out_NCHW_shape(data_NCHW_shape);
-  out_NCHW_shape.Set(2, grid_shape->values[1]);  // H_out
-  out_NCHW_shape.Set(3, grid_shape->values[2]);  // W_out
+  ffi::Array<PrimExpr> data_tgt_shape = 
data2tgt.ForwardShape(data_shape.value()->values);
+  ffi::Array<PrimExpr> out_tgt_shape(data_tgt_shape);
+  if (is_ncdhw) {
+    // grid (TVM layout) is [N, 3, D_out, H_out, W_out], output is
+    // [N, C, D_out, H_out, W_out]; the spatial extents are grid->values[2:].
+    out_tgt_shape.Set(2, grid_shape->values[2]);  // D_out
+    out_tgt_shape.Set(3, grid_shape->values[3]);  // H_out
+    out_tgt_shape.Set(4, grid_shape->values[4]);  // W_out
+  } else {
+    // grid (TVM layout) is [N, 2, H_out, W_out], output is [N, C, H_out, 
W_out]
+    out_tgt_shape.Set(2, grid_shape->values[2]);  // H_out
+    out_tgt_shape.Set(3, grid_shape->values[3]);  // W_out
+  }
 
-  ffi::Array<PrimExpr> out_shape = data2NCHW.BackwardShape(out_NCHW_shape);
+  ffi::Array<PrimExpr> out_shape = data2tgt.BackwardShape(out_tgt_shape);
   return TensorStructInfo(ShapeExpr(out_shape), out_dtype, 
data_sinfo->vdevice);
 }
 
diff --git a/tests/python/relax/test_frontend_onnx.py 
b/tests/python/relax/test_frontend_onnx.py
index 5aff95da5a..57f780868c 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -5629,7 +5629,6 @@ def test_affine_grid():
 @pytest.mark.parametrize("padding_mode", ["zeros", "border", "reflection"])
 @pytest.mark.parametrize("align_corners", [0, 1])
 def test_grid_sample(mode, padding_mode, align_corners):
-    # Only testing 2D (NCHW) as that's what TVM currently supports
     x_shape = [1, 3, 4, 4]
     grid_shape = [1, 2, 2, 2]
     out_shape = [x_shape[0], x_shape[1], grid_shape[1], grid_shape[2]]
@@ -5668,6 +5667,139 @@ def test_grid_sample(mode, padding_mode, align_corners):
     )
 
 
[email protected]("mode", ["bilinear", "nearest"])
[email protected]("padding_mode", ["zeros", "border", "reflection"])
[email protected]("align_corners", [0, 1])
+def test_grid_sample_5d(mode, padding_mode, align_corners):
+    x_shape = [1, 1, 4, 4, 4]
+    grid_shape = [1, 4, 4, 4, 3]
+    out_shape = [x_shape[0], x_shape[1], grid_shape[1], grid_shape[2], 
grid_shape[3]]
+
+    node = helper.make_node(
+        "GridSample",
+        inputs=["X", "grid"],
+        outputs=["Y"],
+        mode=mode,
+        padding_mode=padding_mode,
+        align_corners=align_corners,
+    )
+
+    graph = helper.make_graph(
+        [node],
+        "grid_sample_5d_test",
+        inputs=[
+            helper.make_tensor_value_info("X", TensorProto.FLOAT, x_shape),
+            helper.make_tensor_value_info("grid", TensorProto.FLOAT, 
grid_shape),
+        ],
+        outputs=[
+            helper.make_tensor_value_info("Y", TensorProto.FLOAT, out_shape),
+        ],
+    )
+
+    rng = np.random.default_rng(0)
+    grid_data = rng.uniform(-1.25, 1.25, grid_shape).astype("float32")
+    x_data = rng.uniform(-1, 1, x_shape).astype("float32")
+
+    model = helper.make_model(graph, producer_name="grid_sample_5d_test")
+    check_correctness(
+        model,
+        inputs={"grid": grid_data, "X": x_data},
+        opset=16,
+        rtol=1e-5,
+        atol=1e-5,
+    )
+
+
+def test_grid_sample_5d_cubic_unsupported():
+    x_shape = [1, 1, 4, 4, 4]
+    grid_shape = [1, 2, 3, 5, 3]
+    out_shape = [x_shape[0], x_shape[1], grid_shape[1], grid_shape[2], 
grid_shape[3]]
+
+    node = helper.make_node(
+        "GridSample",
+        inputs=["X", "grid"],
+        outputs=["Y"],
+        mode="cubic",
+    )
+
+    graph = helper.make_graph(
+        [node],
+        "grid_sample_5d_cubic_unsupported_test",
+        inputs=[
+            helper.make_tensor_value_info("X", TensorProto.FLOAT, x_shape),
+            helper.make_tensor_value_info("grid", TensorProto.FLOAT, 
grid_shape),
+        ],
+        outputs=[
+            helper.make_tensor_value_info("Y", TensorProto.FLOAT, out_shape),
+        ],
+    )
+
+    model = helper.make_model(graph, 
producer_name="grid_sample_5d_cubic_unsupported_test")
+    with pytest.raises(
+        NotImplementedError,
+        match="5D .*GridSample with mode='cubic' is not supported",
+    ):
+        from_onnx(model, opset=16, keep_params_in_input=True)
+
+
+def test_grid_sample_4d_non_square_output_shape():
+    x_shape = [1, 3, 4, 4]
+    grid_shape = [1, 3, 5, 2]
+    out_shape = [x_shape[0], x_shape[1], grid_shape[1], grid_shape[2]]
+
+    node = helper.make_node(
+        "GridSample",
+        inputs=["X", "grid"],
+        outputs=["Y"],
+        mode="bilinear",
+    )
+
+    graph = helper.make_graph(
+        [node],
+        "grid_sample_4d_non_square_output_shape_test",
+        inputs=[
+            helper.make_tensor_value_info("X", TensorProto.FLOAT, x_shape),
+            helper.make_tensor_value_info("grid", TensorProto.FLOAT, 
grid_shape),
+        ],
+        outputs=[
+            helper.make_tensor_value_info("Y", TensorProto.FLOAT, out_shape),
+        ],
+    )
+
+    model = helper.make_model(graph, 
producer_name="grid_sample_4d_non_square_output_shape_test")
+    tvm_model = from_onnx(model, opset=16, keep_params_in_input=True)
+    inferred_shape = tuple(dim.value for dim in 
tvm_model["main"].ret_struct_info.shape.values)
+    assert inferred_shape == tuple(out_shape)
+
+
+def test_grid_sample_unsupported_rank():
+    x_shape = [1, 3, 4]
+    grid_shape = [1, 4, 2]
+
+    node = helper.make_node(
+        "GridSample",
+        inputs=["X", "grid"],
+        outputs=["Y"],
+        mode="bilinear",
+    )
+
+    graph = helper.make_graph(
+        [node],
+        "grid_sample_unsupported_rank_test",
+        inputs=[
+            helper.make_tensor_value_info("X", TensorProto.FLOAT, x_shape),
+            helper.make_tensor_value_info("grid", TensorProto.FLOAT, 
grid_shape),
+        ],
+        outputs=[
+            helper.make_tensor_value_info("Y", TensorProto.FLOAT, x_shape),
+        ],
+    )
+
+    model = helper.make_model(graph, 
producer_name="grid_sample_unsupported_rank_test")
+    with pytest.raises(NotImplementedError, match="GridSample only supports 4D 
or 5D input"):
+        from_onnx(model, opset=16, keep_params_in_input=True)
+
+
 def test_grid_sample_linear_mode_translation():
     """Test that ONNX mode='linear' is correctly translated to 'bilinear'.
 

Reply via email to