This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6383c7fd7f [Relax][ONNX] Support 3D AffineGrid (#19863)
6383c7fd7f is described below

commit 6383c7fd7f9edbd73e8d81a0b0f39d19e3b0e331
Author: Guan-Ming Chiu <[email protected]>
AuthorDate: Wed Jul 1 03:35:22 2026 +0800

    [Relax][ONNX] Support 3D AffineGrid (#19863)
    
    ## Related Issue
    
    closes #19689
    
    ## Why
    
    The Relax AffineGrid op only handled the 2D case (size [N, C, H, W],
    theta [N, 2, 3]); 3D inputs from ONNX (size [N, C, D, H, W], theta
    [N, 3, 4]) failed.
    
    ## How
    
    - Generalize type inference (InferTypeAffineGrid) to 2D/3D via
    spatial = size_ty->ndim.
    - Generalize the TOPI affine_grid compute over the number of spatial
    dims (2D or 3D) instead of hard-coding H/W.
    - Generalize the frontend's permute axes to cover 3D and add a
    test_affine_grid_3d case; update the expected TIR in
    test_image_affine_grid for the reordered compute.
    
    ---------
    
    Signed-off-by: Guan-Ming (Wesley) Chiu <[email protected]>
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py    | 17 ++++----
 python/tvm/relax/op/image/image.py                 | 15 ++++---
 python/tvm/topi/image/grid_sample.py               | 46 +++++++++++-----------
 src/relax/op/image/resize.cc                       | 38 ++++++++++--------
 tests/python/relax/test_frontend_onnx.py           | 31 +++++++++++++++
 .../relax/test_transform_legalize_ops_image.py     |  8 ++--
 6 files changed, 96 insertions(+), 59 deletions(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py 
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 87675a3f9c..7a0f25bf83 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -3342,15 +3342,14 @@ class AffineGrid(OnnxOpConverter):
         else:
             raise NotImplementedError(f"Dynamic size of type {type(size)} is 
not supported")
 
-        # Only 2D is supported: size = [N, C, H, W]
-        if len(size_vals) != 4:
-            raise ValueError("Only 2D AffineGrid (size=[N,C,H,W]) is 
supported")
-        target_h, target_w = size_vals[2], size_vals[3]
-
-        # Relax affine_grid outputs [N, 2, H, W]
-        grid = bb.emit(relax.op.image.affine_grid(theta, (target_h, target_w), 
align_corners))
-        # Permute to ONNX convention [N, H, W, 2]
-        return bb.emit(relax.op.permute_dims(grid, axes=[0, 2, 3, 1]))
+        if len(size_vals) not in (4, 5):
+            raise ValueError("AffineGrid expects size to be [N,C,H,W] (2D) or 
[N,C,D,H,W] (3D)")
+
+        # relax affine_grid outputs [N, spatial, *spatial_dims]; move the 
coord axis
+        # last to match the ONNX convention [N, *spatial_dims, spatial].
+        grid = bb.emit(relax.op.image.affine_grid(theta, tuple(size_vals[2:]), 
align_corners))
+        axes = [0, *range(2, len(size_vals)), 1]
+        return bb.emit(relax.op.permute_dims(grid, axes=axes))
 
 
 class Einsum(OnnxOpConverter):
diff --git a/python/tvm/relax/op/image/image.py 
b/python/tvm/relax/op/image/image.py
index 29aa8457d2..91d7468808 100644
--- a/python/tvm/relax/op/image/image.py
+++ b/python/tvm/relax/op/image/image.py
@@ -239,7 +239,7 @@ def affine_grid(
     size: Expr | SizeLike,
     align_corners: bool = True,
 ) -> Expr:
-    """Generate a 2D sampling grid using an affine transformation matrix.
+    """Generate a 2D or 3D sampling grid using an affine transformation matrix.
 
     This operation is described in https://arxiv.org/pdf/1506.02025.pdf.
     It generates a uniform sampling grid within the target shape, normalizes it
@@ -248,11 +248,13 @@ def affine_grid(
     Parameters
     ----------
     data : relax.Expr
-        The input affine matrix tensor with shape [batch, 2, 3].
+        The input affine matrix tensor with shape [batch, 2, 3] for 2D or
+        [batch, 3, 4] for 3D.
 
-    size : Union[Expr, PrimExprLike, Tuple[PrimExprLike, PrimExprLike]]
-        The target output spatial shape (H, W). If a single integer or PrimExpr
-        is provided, it is interpreted as a square output shape (size, size).
+    size : Union[Expr, PrimExprLike, Tuple[PrimExprLike, ...]]
+        The target output spatial shape, (H, W) for 2D or (D, H, W) for 3D. If 
a
+        single integer or PrimExpr is provided, it is interpreted as a square 
2D
+        output shape (size, size).
 
     align_corners : bool
         If True, normalized grid coordinates map to corner pixels; if False, to
@@ -261,7 +263,8 @@ def affine_grid(
     Returns
     -------
     result : relax.Expr
-        The output grid tensor with shape [batch, 2, H, W].
+        The output grid tensor with shape [batch, 2, H, W] for 2D or
+        [batch, 3, D, H, W] for 3D.
     """
     if isinstance(size, int | PrimExpr):
         size = (size, size)
diff --git a/python/tvm/topi/image/grid_sample.py 
b/python/tvm/topi/image/grid_sample.py
index cdfb7f4362..b5058c6d4a 100644
--- a/python/tvm/topi/image/grid_sample.py
+++ b/python/tvm/topi/image/grid_sample.py
@@ -21,7 +21,7 @@ from tvm import te, tirx
 
 
 def affine_grid(data, target_shape, align_corners=True):
-    """affine_grid operator that generates 2D sampling grid.
+    """affine_grid operator that generates a 2D or 3D sampling grid.
 
     This operation is described in https://arxiv.org/pdf/1506.02025.pdf. It 
generates a uniform
     sampling grid within the target shape and normalizes it to [-1, 1]. The 
provided affine
@@ -30,10 +30,10 @@ def affine_grid(data, target_shape, align_corners=True):
     Parameters
     ----------
     data : tvm.Tensor
-        3-D with shape [batch, 2, 3]. The affine matrix.
+        3-D with shape [batch, 2, 3] for 2D or [batch, 3, 4] for 3D. The 
affine matrix.
 
-    target_shape: list/tuple of two int
-        Specifies the output shape (H, W).
+    target_shape: list/tuple of int
+        Specifies the output spatial shape (H, W) for 2D or (D, H, W) for 3D.
 
     align_corners : bool
         If True, normalized coordinates map to corner pixels; if False, to 
pixel centers
@@ -42,35 +42,35 @@ def affine_grid(data, target_shape, align_corners=True):
     Returns
     -------
     Output : tvm.Tensor
-        4-D with shape [batch, 2, target_height, target_width]
+        [batch, 2, H, W] for 2D or [batch, 3, D, H, W] for 3D.
     """
-    assert target_shape is not None
-    assert len(target_shape) == 2
+    assert len(target_shape) in (2, 3)
     if align_corners:
-        assert target_shape[0] > 1 and target_shape[1] > 1, (
-            "target height/width should be greater than 1 when align_corners 
is True"
+        assert all(s > 1 for s in target_shape), (
+            "target spatial dims should be greater than 1 when align_corners 
is True"
         )
 
     dtype = data.dtype
-    height, width = target_shape[0], target_shape[1]
     if align_corners:
-        y_step = tirx.const((2.0 - 1e-7) / (height - 1), dtype=dtype)
-        x_step = tirx.const((2.0 - 1e-7) / (width - 1), dtype=dtype)
-        y_start = tirx.const(-1.0, dtype=dtype)
-        x_start = tirx.const(-1.0, dtype=dtype)
+        starts = [tirx.const(-1.0, dtype=dtype) for _ in target_shape]
+        steps = [tirx.const((2.0 - 1e-7) / (s - 1), dtype=dtype) for s in 
target_shape]
     else:
         # Pixel centers: coordinate i maps to (2 * i + 1) / size - 1.
-        y_step = tirx.const(2.0 / height, dtype=dtype)
-        x_step = tirx.const(2.0 / width, dtype=dtype)
-        y_start = tirx.const(-1.0 + 1.0 / height, dtype=dtype)
-        x_start = tirx.const(-1.0 + 1.0 / width, dtype=dtype)
+        starts = [tirx.const(-1.0 + 1.0 / s, dtype=dtype) for s in 
target_shape]
+        steps = [tirx.const(2.0 / s, dtype=dtype) for s in target_shape]
 
-    def _compute(n, dim, i, j):
-        y = y_start + i * y_step
-        x = x_start + j * x_step
-        return data[n, dim, 0] * x + data[n, dim, 1] * y + data[n, dim, 2]
+    ndim = len(target_shape)
 
-    oshape = (data.shape[0], len(target_shape), *target_shape)
+    def _compute(n, dim, *coords):
+        # coords are ordered slowest-to-fastest (e.g. (k, i, j)); the affine 
matrix
+        # columns are fastest-to-slowest (x, y, z), so index it in reverse.
+        val = data[n, dim, ndim]  # translation column
+        for r in range(ndim):
+            coord = starts[r] + coords[r] * steps[r]
+            val += data[n, dim, ndim - 1 - r] * coord
+        return val
+
+    oshape = (data.shape[0], ndim, *target_shape)
     return te.compute(oshape, _compute, tag="affine_grid")
 
 
diff --git a/src/relax/op/image/resize.cc b/src/relax/op/image/resize.cc
index 7dd80bf22a..dcd36465b6 100644
--- a/src/relax/op/image/resize.cc
+++ b/src/relax/op/image/resize.cc
@@ -393,35 +393,38 @@ Type InferTypeAffineGrid(const Call& call, const 
BlockBuilder& ctx) {
         << "AffineGrid expects the target size to be a Shape, while the given 
one is "
         << call->args[1]->GetTypeKey();
   }
-  if (size_ty->ndim != 2) {
+  // 2D output uses a 2-dim spatial size ([H, W]); 3D uses a 3-dim size ([D, 
H, W]).
+  if (size_ty->ndim != 2 && size_ty->ndim != 3) {
     TVM_FFI_VISIT_THROW(ValueError, call)
-        << "AffineGrid expects the target size to be a 2-dim shape, while the 
given "
+        << "AffineGrid expects the target size to be a 2-dim or 3-dim shape, 
while the given "
            "one has ndim "
         << size_ty->ndim;
   }
+  const int spatial = size_ty->ndim;
 
-  // data should be 3-D: [batch, 2, 3]
+  // data should be 3-D: [batch, spatial, spatial + 1] (i.e. [N, 2, 3] or [N, 
3, 4]).
   if (data_ty->ndim != -1 && data_ty->ndim != 3) {
-    TVM_FFI_VISIT_THROW(ValueError, call)
-        << "AffineGrid expects the input data to be 3-D (batch, 2, 3), but got 
ndim "
-        << data_ty->ndim;
+    TVM_FFI_VISIT_THROW(ValueError, call) << "AffineGrid expects the input 
data to be 3-D (batch, "
+                                             "spatial, spatial + 1), but got 
ndim "
+                                          << data_ty->ndim;
   }
 
   const auto* data_shape = data_ty->shape.as<ShapeExprNode>();
   if (data_shape != nullptr) {
-    // Check that the affine matrix has shape [batch, 2, 3]
     if (data_shape->values.size() >= 2) {
       auto* dim1 = data_shape->values[1].as<IntImmNode>();
-      if (dim1 != nullptr && dim1->value != 2) {
+      if (dim1 != nullptr && dim1->value != spatial) {
         TVM_FFI_VISIT_THROW(ValueError, call)
-            << "AffineGrid expects the second dimension of input to be 2, but 
got " << dim1->value;
+            << "AffineGrid expects the second dimension of input to be " << 
spatial << ", but got "
+            << dim1->value;
       }
     }
     if (data_shape->values.size() >= 3) {
       auto* dim2 = data_shape->values[2].as<IntImmNode>();
-      if (dim2 != nullptr && dim2->value != 3) {
+      if (dim2 != nullptr && dim2->value != spatial + 1) {
         TVM_FFI_VISIT_THROW(ValueError, call)
-            << "AffineGrid expects the third dimension of input to be 3, but 
got " << dim2->value;
+            << "AffineGrid expects the third dimension of input to be " << 
spatial + 1
+            << ", but got " << dim2->value;
       }
     }
   }
@@ -429,15 +432,16 @@ Type InferTypeAffineGrid(const Call& call, const 
BlockBuilder& ctx) {
   ffi::Optional<PrimType> out_dtype = data_ty->dtype;
 
   if (data_shape == nullptr || size_value == nullptr) {
-    return TensorType(out_dtype, /*ndim=*/4, data_ty->vdevice);
+    return TensorType(out_dtype, /*ndim=*/spatial + 2, data_ty->vdevice);
   }
 
-  // Output shape: [batch, 2, target_height, target_width]
+  // Output shape: [batch, spatial, *target_spatial_dims].
   ffi::Array<PrimExpr> out_shape;
-  out_shape.push_back(data_shape->values[0]);  // batch
-  out_shape.push_back(IntImm::Int64(2));       // 2 (spatial dimensions)
-  out_shape.push_back(size_value->values[0]);  // target_height
-  out_shape.push_back(size_value->values[1]);  // target_width
+  out_shape.push_back(data_shape->values[0]);   // batch
+  out_shape.push_back(IntImm::Int64(spatial));  // number of spatial 
coordinates
+  for (int i = 0; i < spatial; ++i) {
+    out_shape.push_back(size_value->values[i]);  // target spatial dim
+  }
 
   return TensorType(ShapeExpr(out_shape), out_dtype, data_ty->vdevice);
 }
diff --git a/tests/python/relax/test_frontend_onnx.py 
b/tests/python/relax/test_frontend_onnx.py
index 972bc48307..b05b8e3742 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -10150,6 +10150,37 @@ def test_affine_grid():
     verify_affine_grid(1, ExpectedAlignCorners)
 
 
+def test_affine_grid_3d():
+    affine_grid_node = helper.make_node(
+        "AffineGrid",
+        ["theta", "size"],
+        ["grid"],
+        align_corners=1,
+    )
+
+    graph = helper.make_graph(
+        [affine_grid_node],
+        "affine_grid_3d_test",
+        inputs=[
+            helper.make_tensor_value_info("theta", TensorProto.FLOAT, [2, 3, 
4]),
+        ],
+        initializer=[
+            helper.make_tensor("size", TensorProto.INT64, [5], [2, 3, 8, 16, 
16]),
+        ],
+        outputs=[
+            helper.make_tensor_value_info("grid", TensorProto.FLOAT, [2, 8, 
16, 16, 3]),
+        ],
+    )
+
+    model = helper.make_model(graph, producer_name="affine_grid_3d_test")
+
+    tvm_model = from_onnx(model, opset=20, keep_params_in_input=True)
+    call_ops = collect_relax_call_ops(tvm_model["main"])
+    assert "relax.image.affine_grid" in call_ops
+    assert "relax.permute_dims" in call_ops
+    assert [int(d) for d in tvm_model["main"].ret_ty.shape] == [2, 8, 16, 16, 
3]
+
+
 @pytest.mark.parametrize("mode", ["bilinear", "nearest", "bicubic"])
 @pytest.mark.parametrize("padding_mode", ["zeros", "border", "reflection"])
 @pytest.mark.parametrize("align_corners", [0, 1])
diff --git a/tests/python/relax/test_transform_legalize_ops_image.py 
b/tests/python/relax/test_transform_legalize_ops_image.py
index c91c4ddb8b..3313f54f53 100644
--- a/tests/python/relax/test_transform_legalize_ops_image.py
+++ b/tests/python/relax/test_transform_legalize_ops_image.py
@@ -126,12 +126,12 @@ def test_image_affine_grid():
             with T.sblock("root"):
                 T.reads()
                 T.writes()
-                for n, dim, i, j in T.grid(T.int64(2), T.int64(2), 
T.int64(16), T.int64(16)):
+                for n, dim, i0, i1 in T.grid(T.int64(2), T.int64(2), 
T.int64(16), T.int64(16)):
                     with T.sblock("compute"):
-                        v_n, v_dim, v_i, v_j = T.axis.remap("SSSS", [n, dim, 
i, j])
+                        v_n, v_dim, v_i0, v_i1 = T.axis.remap("SSSS", [n, dim, 
i0, i1])
                         T.reads(theta[v_n, v_dim, T.int64(0):T.int64(3)])
-                        T.writes(compute[v_n, v_dim, v_i, v_j])
-                        compute[v_n, v_dim, v_i, v_j] = theta[v_n, v_dim, 
T.int64(0)] * (T.float32(-1.0) + T.Cast("float32", v_j) * 
T.float32(0.13333332666666667)) + theta[v_n, v_dim, T.int64(1)] * 
(T.float32(-1.0) + T.Cast("float32", v_i) * T.float32(0.13333332666666667)) + 
theta[v_n, v_dim, T.int64(2)]
+                        T.writes(compute[v_n, v_dim, v_i0, v_i1])
+                        compute[v_n, v_dim, v_i0, v_i1] = theta[v_n, v_dim, 
T.int64(2)] + theta[v_n, v_dim, T.int64(1)] * (T.float32(-1.0) + 
T.Cast("float32", v_i0) * T.float32(0.13333332666666667)) + theta[v_n, v_dim, 
T.int64(0)] * (T.float32(-1.0) + T.Cast("float32", v_i1) * 
T.float32(0.13333332666666667))
     # fmt: on
 
     mod = LegalizeOps()(AffineGrid)

Reply via email to