This is an automated email from the ASF dual-hosted git repository.

guan404ming pushed a commit to branch fix/onnx-affinegrid-3d
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 8f704ef160dddafb68150c3bb332c68ea56afc39
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Fri Jun 19 23:42:10 2026 +0800

    [Relax][ONNX] Support 3D AffineGrid
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py | 20 +++++++------
 python/tvm/relax/op/image/image.py              | 15 ++++++----
 python/tvm/topi/image/grid_sample.py            | 37 ++++++++++++++----------
 src/relax/op/image/resize.cc                    | 38 ++++++++++++++-----------
 tests/python/relax/test_frontend_onnx.py        | 26 +++++++++++++++++
 5 files changed, 89 insertions(+), 47 deletions(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index d550d3bc00..0bd6627a8c 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -3322,15 +3322,17 @@ class AffineGrid(OnnxOpConverter):
         else:
             raise NotImplementedError(f"Dynamic size of type {type(size)} is not supported")
 
-        # Only 2D is supported: size = [N, C, H, W]
-        if len(size_vals) != 4:
-            raise ValueError("Only 2D AffineGrid (size=[N,C,H,W]) is supported")
-        target_h, target_w = size_vals[2], size_vals[3]
-
-        # Relax affine_grid outputs [N, 2, H, W]
-        grid = bb.emit(relax.op.image.affine_grid(theta, (target_h, target_w)))
-        # Permute to ONNX convention [N, H, W, 2]
-        return bb.emit(relax.op.permute_dims(grid, axes=[0, 2, 3, 1]))
+        if len(size_vals) == 4:
+            # 2D: size = [N, C, H, W]; relax affine_grid outputs [N, 2, H, W].
+            grid = bb.emit(relax.op.image.affine_grid(theta, tuple(size_vals[2:])))
+            # Permute to ONNX convention [N, H, W, 2].
+            return bb.emit(relax.op.permute_dims(grid, axes=[0, 2, 3, 1]))
+        if len(size_vals) == 5:
+            # 3D: size = [N, C, D, H, W]; relax affine_grid outputs [N, 3, D, H, W].
+            grid = bb.emit(relax.op.image.affine_grid(theta, tuple(size_vals[2:])))
+            # Permute to ONNX convention [N, D, H, W, 3].
+            return bb.emit(relax.op.permute_dims(grid, axes=[0, 2, 3, 4, 1]))
+        raise ValueError("AffineGrid expects size to be [N,C,H,W] (2D) or [N,C,D,H,W] (3D)")
 
 
 class Einsum(OnnxOpConverter):
diff --git a/python/tvm/relax/op/image/image.py b/python/tvm/relax/op/image/image.py
index 323bfa74b5..09837bfbb2 100644
--- a/python/tvm/relax/op/image/image.py
+++ b/python/tvm/relax/op/image/image.py
@@ -238,7 +238,7 @@ def affine_grid(
     data: Expr,
     size: Expr | SizeLike,
 ) -> Expr:
-    """Generate a 2D sampling grid using an affine transformation matrix.
+    """Generate a 2D or 3D sampling grid using an affine transformation matrix.
 
     This operation is described in https://arxiv.org/pdf/1506.02025.pdf.
     It generates a uniform sampling grid within the target shape, normalizes it
@@ -247,16 +247,19 @@ def affine_grid(
     Parameters
     ----------
     data : relax.Expr
-        The input affine matrix tensor with shape [batch, 2, 3].
+        The input affine matrix tensor with shape [batch, 2, 3] for 2D or
+        [batch, 3, 4] for 3D.
 
-    size : Union[Expr, PrimExprLike, Tuple[PrimExprLike, PrimExprLike]]
-        The target output spatial shape (H, W). If a single integer or PrimExpr
-        is provided, it is interpreted as a square output shape (size, size).
+    size : Union[Expr, PrimExprLike, Tuple[PrimExprLike, ...]]
+        The target output spatial shape, (H, W) for 2D or (D, H, W) for 3D. If a
+        single integer or PrimExpr is provided, it is interpreted as a square 2D
+        output shape (size, size).
 
     Returns
     -------
     result : relax.Expr
-        The output grid tensor with shape [batch, 2, H, W].
+        The output grid tensor with shape [batch, 2, H, W] for 2D or
+        [batch, 3, D, H, W] for 3D.
 
     Note
     ----
diff --git a/python/tvm/topi/image/grid_sample.py b/python/tvm/topi/image/grid_sample.py
index 79032f41b3..15f6178cd6 100644
--- a/python/tvm/topi/image/grid_sample.py
+++ b/python/tvm/topi/image/grid_sample.py
@@ -21,7 +21,7 @@ from tvm import te, tirx
 
 
 def affine_grid(data, target_shape):
-    """affine_grid operator that generates 2D sampling grid.
+    """affine_grid operator that generates a 2D or 3D sampling grid.
 
     This operation is described in https://arxiv.org/pdf/1506.02025.pdf. It generates a uniform
     sampling grid within the target shape and normalizes it to [-1, 1]. The provided affine
@@ -30,31 +30,38 @@ def affine_grid(data, target_shape):
     Parameters
     ----------
     data : tvm.Tensor
-        3-D with shape [batch, 2, 3]. The affine matrix.
+        3-D with shape [batch, 2, 3] for 2D or [batch, 3, 4] for 3D. The affine matrix.
 
-    target_shape: list/tuple of two int
-        Specifies the output shape (H, W).
+    target_shape: list/tuple of int
+        Specifies the output spatial shape (H, W) for 2D or (D, H, W) for 3D.
 
     Returns
     -------
     Output : tvm.Tensor
-        4-D with shape [batch, 2, target_height, target_width]
+        [batch, 2, H, W] for 2D or [batch, 3, D, H, W] for 3D.
     """
     assert target_shape is not None
-    assert len(target_shape) == 2
-    assert target_shape[0] > 1 and target_shape[1] > 1, (
-        "target height/width should be greater than 1"
-    )
+    assert len(target_shape) in (2, 3)
+    assert all(s > 1 for s in target_shape), "target spatial dims should be greater than 1"
 
     dtype = data.dtype
-    y_step = tirx.const((2.0 - 1e-7) / (target_shape[0] - 1), dtype=dtype)
-    x_step = tirx.const((2.0 - 1e-7) / (target_shape[1] - 1), dtype=dtype)
     start = tirx.const(-1.0, dtype=dtype)
+    steps = [tirx.const((2.0 - 1e-7) / (s - 1), dtype=dtype) for s in target_shape]
+
+    if len(target_shape) == 2:
+
+        def _compute(n, dim, i, j):
+            y = start + i * steps[0]
+            x = start + j * steps[1]
+            return data[n, dim, 0] * x + data[n, dim, 1] * y + data[n, dim, 2]
+
+    else:
 
-    def _compute(n, dim, i, j):
-        y = start + i * y_step
-        x = start + j * x_step
-        return data[n, dim, 0] * x + data[n, dim, 1] * y + data[n, dim, 2]
+        def _compute(n, dim, k, i, j):
+            z = start + k * steps[0]
+            y = start + i * steps[1]
+            x = start + j * steps[2]
+            return data[n, dim, 0] * x + data[n, dim, 1] * y + data[n, dim, 2] * z + data[n, dim, 3]
 
     oshape = (data.shape[0], len(target_shape), *target_shape)
     return te.compute(oshape, _compute, tag="affine_grid")
diff --git a/src/relax/op/image/resize.cc b/src/relax/op/image/resize.cc
index beb89af087..86a8dc0095 100644
--- a/src/relax/op/image/resize.cc
+++ b/src/relax/op/image/resize.cc
@@ -385,35 +385,38 @@ Type InferTypeAffineGrid(const Call& call, const BlockBuilder& ctx) {
         << "AffineGrid expects the target size to be a Shape, while the given one is "
         << call->args[1]->GetTypeKey();
   }
-  if (size_ty->ndim != 2) {
+  // 2D output uses a 2-dim spatial size ([H, W]); 3D uses a 3-dim size ([D, H, W]).
+  if (size_ty->ndim != 2 && size_ty->ndim != 3) {
     TVM_FFI_VISIT_THROW(ValueError, call)
-        << "AffineGrid expects the target size to be a 2-dim shape, while the given "
+        << "AffineGrid expects the target size to be a 2-dim or 3-dim shape, while the given "
            "one has ndim "
         << size_ty->ndim;
   }
+  const int spatial = size_ty->ndim;
 
-  // data should be 3-D: [batch, 2, 3]
+  // data should be 3-D: [batch, spatial, spatial + 1] (i.e. [N, 2, 3] or [N, 3, 4]).
   if (data_ty->ndim != -1 && data_ty->ndim != 3) {
-    TVM_FFI_VISIT_THROW(ValueError, call)
-        << "AffineGrid expects the input data to be 3-D (batch, 2, 3), but got ndim "
-        << data_ty->ndim;
+    TVM_FFI_VISIT_THROW(ValueError, call) << "AffineGrid expects the input data to be 3-D (batch, "
+                                             "spatial, spatial + 1), but got ndim "
+                                          << data_ty->ndim;
   }
 
   const auto* data_shape = data_ty->shape.as<ShapeExprNode>();
   if (data_shape != nullptr) {
-    // Check that the affine matrix has shape [batch, 2, 3]
     if (data_shape->values.size() >= 2) {
       auto* dim1 = data_shape->values[1].as<IntImmNode>();
-      if (dim1 != nullptr && dim1->value != 2) {
+      if (dim1 != nullptr && dim1->value != spatial) {
         TVM_FFI_VISIT_THROW(ValueError, call)
-            << "AffineGrid expects the second dimension of input to be 2, but got " << dim1->value;
+            << "AffineGrid expects the second dimension of input to be " << spatial << ", but got "
+            << dim1->value;
       }
     }
     if (data_shape->values.size() >= 3) {
       auto* dim2 = data_shape->values[2].as<IntImmNode>();
-      if (dim2 != nullptr && dim2->value != 3) {
+      if (dim2 != nullptr && dim2->value != spatial + 1) {
         TVM_FFI_VISIT_THROW(ValueError, call)
-            << "AffineGrid expects the third dimension of input to be 3, but got " << dim2->value;
+            << "AffineGrid expects the third dimension of input to be " << spatial + 1
+            << ", but got " << dim2->value;
       }
     }
   }
@@ -421,15 +424,16 @@ Type InferTypeAffineGrid(const Call& call, const BlockBuilder& ctx) {
   DataType out_dtype = data_ty->dtype;
 
   if (data_shape == nullptr || size_value == nullptr) {
-    return TensorType(out_dtype, /*ndim=*/4, data_ty->vdevice);
+    return TensorType(out_dtype, /*ndim=*/spatial + 2, data_ty->vdevice);
   }
 
-  // Output shape: [batch, 2, target_height, target_width]
+  // Output shape: [batch, spatial, *target_spatial_dims].
   ffi::Array<PrimExpr> out_shape;
-  out_shape.push_back(data_shape->values[0]);  // batch
-  out_shape.push_back(IntImm::Int64(2));       // 2 (spatial dimensions)
-  out_shape.push_back(size_value->values[0]);  // target_height
-  out_shape.push_back(size_value->values[1]);  // target_width
+  out_shape.push_back(data_shape->values[0]);   // batch
+  out_shape.push_back(IntImm::Int64(spatial));  // number of spatial coordinates
+  for (int i = 0; i < spatial; ++i) {
+    out_shape.push_back(size_value->values[i]);  // target spatial dim
+  }
 
   return TensorType(ShapeExpr(out_shape), out_dtype, data_ty->vdevice);
 }
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index d3036a2547..b47bd1c43b 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -5581,6 +5581,32 @@ def test_affine_grid():
     check_correctness(model, opset=20)
 
 
+def test_affine_grid_3d():
+    affine_grid_node = helper.make_node(
+        "AffineGrid",
+        ["theta", "size"],
+        ["grid"],
+        align_corners=1,
+    )
+
+    graph = helper.make_graph(
+        [affine_grid_node],
+        "affine_grid_3d_test",
+        inputs=[
+            helper.make_tensor_value_info("theta", TensorProto.FLOAT, [2, 3, 4]),
+        ],
+        initializer=[
+            helper.make_tensor("size", TensorProto.INT64, [5], [2, 3, 8, 16, 16]),
+        ],
+        outputs=[
+            helper.make_tensor_value_info("grid", TensorProto.FLOAT, [2, 8, 16, 16, 3]),
+        ],
+    )
+
+    model = helper.make_model(graph, producer_name="affine_grid_3d_test")
+    check_correctness(model, opset=20)
+
+
 @pytest.mark.parametrize("mode", ["bilinear", "nearest", "bicubic"])
 @pytest.mark.parametrize("padding_mode", ["zeros", "border", "reflection"])
 @pytest.mark.parametrize("align_corners", [0, 1])

Reply via email to