This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6383c7fd7f [Relax][ONNX] Support 3D AffineGrid (#19863)
6383c7fd7f is described below
commit 6383c7fd7f9edbd73e8d81a0b0f39d19e3b0e331
Author: Guan-Ming Chiu <[email protected]>
AuthorDate: Wed Jul 1 03:35:22 2026 +0800
[Relax][ONNX] Support 3D AffineGrid (#19863)
## Related Issue
closes #19689
## Why
The Relax AffineGrid op only handled the 2D case (theta of shape
[N, 2, 3], 4-D output grid); the 3D case coming from ONNX (theta of
shape [N, 3, 4], 5-D output grid) failed.
## How
- Generalize struct-info inference to 2D/3D via spatial =
size_ty->ndim.
- Generalize the TOPI affine_grid compute to cover both 2D and 3D with a
single loop over the spatial axes.
- Add the 3D permute path in the frontend and a test_affine_grid_3d
case.
---------
Signed-off-by: Guan-Ming (Wesley) Chiu
<[email protected]>
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 17 ++++----
python/tvm/relax/op/image/image.py | 15 ++++---
python/tvm/topi/image/grid_sample.py | 46 +++++++++++-----------
src/relax/op/image/resize.cc | 38 ++++++++++--------
tests/python/relax/test_frontend_onnx.py | 31 +++++++++++++++
.../relax/test_transform_legalize_ops_image.py | 8 ++--
6 files changed, 96 insertions(+), 59 deletions(-)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 87675a3f9c..7a0f25bf83 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -3342,15 +3342,14 @@ class AffineGrid(OnnxOpConverter):
else:
raise NotImplementedError(f"Dynamic size of type {type(size)} is
not supported")
- # Only 2D is supported: size = [N, C, H, W]
- if len(size_vals) != 4:
- raise ValueError("Only 2D AffineGrid (size=[N,C,H,W]) is
supported")
- target_h, target_w = size_vals[2], size_vals[3]
-
- # Relax affine_grid outputs [N, 2, H, W]
- grid = bb.emit(relax.op.image.affine_grid(theta, (target_h, target_w),
align_corners))
- # Permute to ONNX convention [N, H, W, 2]
- return bb.emit(relax.op.permute_dims(grid, axes=[0, 2, 3, 1]))
+ if len(size_vals) not in (4, 5):
+ raise ValueError("AffineGrid expects size to be [N,C,H,W] (2D) or
[N,C,D,H,W] (3D)")
+
+ # relax affine_grid outputs [N, spatial, *spatial_dims]; move the
coord axis
+ # last to match the ONNX convention [N, *spatial_dims, spatial].
+ grid = bb.emit(relax.op.image.affine_grid(theta, tuple(size_vals[2:]),
align_corners))
+ axes = [0, *range(2, len(size_vals)), 1]
+ return bb.emit(relax.op.permute_dims(grid, axes=axes))
class Einsum(OnnxOpConverter):
diff --git a/python/tvm/relax/op/image/image.py
b/python/tvm/relax/op/image/image.py
index 29aa8457d2..91d7468808 100644
--- a/python/tvm/relax/op/image/image.py
+++ b/python/tvm/relax/op/image/image.py
@@ -239,7 +239,7 @@ def affine_grid(
size: Expr | SizeLike,
align_corners: bool = True,
) -> Expr:
- """Generate a 2D sampling grid using an affine transformation matrix.
+ """Generate a 2D or 3D sampling grid using an affine transformation matrix.
This operation is described in https://arxiv.org/pdf/1506.02025.pdf.
It generates a uniform sampling grid within the target shape, normalizes it
@@ -248,11 +248,13 @@ def affine_grid(
Parameters
----------
data : relax.Expr
- The input affine matrix tensor with shape [batch, 2, 3].
+ The input affine matrix tensor with shape [batch, 2, 3] for 2D or
+ [batch, 3, 4] for 3D.
- size : Union[Expr, PrimExprLike, Tuple[PrimExprLike, PrimExprLike]]
- The target output spatial shape (H, W). If a single integer or PrimExpr
- is provided, it is interpreted as a square output shape (size, size).
+ size : Union[Expr, PrimExprLike, Tuple[PrimExprLike, ...]]
+ The target output spatial shape, (H, W) for 2D or (D, H, W) for 3D. If
a
+ single integer or PrimExpr is provided, it is interpreted as a square
2D
+ output shape (size, size).
align_corners : bool
If True, normalized grid coordinates map to corner pixels; if False, to
@@ -261,7 +263,8 @@ def affine_grid(
Returns
-------
result : relax.Expr
- The output grid tensor with shape [batch, 2, H, W].
+ The output grid tensor with shape [batch, 2, H, W] for 2D or
+ [batch, 3, D, H, W] for 3D.
"""
if isinstance(size, int | PrimExpr):
size = (size, size)
diff --git a/python/tvm/topi/image/grid_sample.py
b/python/tvm/topi/image/grid_sample.py
index cdfb7f4362..b5058c6d4a 100644
--- a/python/tvm/topi/image/grid_sample.py
+++ b/python/tvm/topi/image/grid_sample.py
@@ -21,7 +21,7 @@ from tvm import te, tirx
def affine_grid(data, target_shape, align_corners=True):
- """affine_grid operator that generates 2D sampling grid.
+ """affine_grid operator that generates a 2D or 3D sampling grid.
This operation is described in https://arxiv.org/pdf/1506.02025.pdf. It
generates a uniform
sampling grid within the target shape and normalizes it to [-1, 1]. The
provided affine
@@ -30,10 +30,10 @@ def affine_grid(data, target_shape, align_corners=True):
Parameters
----------
data : tvm.Tensor
- 3-D with shape [batch, 2, 3]. The affine matrix.
+ 3-D with shape [batch, 2, 3] for 2D or [batch, 3, 4] for 3D. The
affine matrix.
- target_shape: list/tuple of two int
- Specifies the output shape (H, W).
+ target_shape: list/tuple of int
+ Specifies the output spatial shape (H, W) for 2D or (D, H, W) for 3D.
align_corners : bool
If True, normalized coordinates map to corner pixels; if False, to
pixel centers
@@ -42,35 +42,35 @@ def affine_grid(data, target_shape, align_corners=True):
Returns
-------
Output : tvm.Tensor
- 4-D with shape [batch, 2, target_height, target_width]
+ [batch, 2, H, W] for 2D or [batch, 3, D, H, W] for 3D.
"""
- assert target_shape is not None
- assert len(target_shape) == 2
+ assert len(target_shape) in (2, 3)
if align_corners:
- assert target_shape[0] > 1 and target_shape[1] > 1, (
- "target height/width should be greater than 1 when align_corners
is True"
+ assert all(s > 1 for s in target_shape), (
+ "target spatial dims should be greater than 1 when align_corners
is True"
)
dtype = data.dtype
- height, width = target_shape[0], target_shape[1]
if align_corners:
- y_step = tirx.const((2.0 - 1e-7) / (height - 1), dtype=dtype)
- x_step = tirx.const((2.0 - 1e-7) / (width - 1), dtype=dtype)
- y_start = tirx.const(-1.0, dtype=dtype)
- x_start = tirx.const(-1.0, dtype=dtype)
+ starts = [tirx.const(-1.0, dtype=dtype) for _ in target_shape]
+ steps = [tirx.const((2.0 - 1e-7) / (s - 1), dtype=dtype) for s in
target_shape]
else:
# Pixel centers: coordinate i maps to (2 * i + 1) / size - 1.
- y_step = tirx.const(2.0 / height, dtype=dtype)
- x_step = tirx.const(2.0 / width, dtype=dtype)
- y_start = tirx.const(-1.0 + 1.0 / height, dtype=dtype)
- x_start = tirx.const(-1.0 + 1.0 / width, dtype=dtype)
+ starts = [tirx.const(-1.0 + 1.0 / s, dtype=dtype) for s in
target_shape]
+ steps = [tirx.const(2.0 / s, dtype=dtype) for s in target_shape]
- def _compute(n, dim, i, j):
- y = y_start + i * y_step
- x = x_start + j * x_step
- return data[n, dim, 0] * x + data[n, dim, 1] * y + data[n, dim, 2]
+ ndim = len(target_shape)
- oshape = (data.shape[0], len(target_shape), *target_shape)
+ def _compute(n, dim, *coords):
+ # coords are ordered slowest-to-fastest (e.g. (k, i, j)); the affine
matrix
+ # columns are fastest-to-slowest (x, y, z), so index it in reverse.
+ val = data[n, dim, ndim] # translation column
+ for r in range(ndim):
+ coord = starts[r] + coords[r] * steps[r]
+ val += data[n, dim, ndim - 1 - r] * coord
+ return val
+
+ oshape = (data.shape[0], ndim, *target_shape)
return te.compute(oshape, _compute, tag="affine_grid")
diff --git a/src/relax/op/image/resize.cc b/src/relax/op/image/resize.cc
index 7dd80bf22a..dcd36465b6 100644
--- a/src/relax/op/image/resize.cc
+++ b/src/relax/op/image/resize.cc
@@ -393,35 +393,38 @@ Type InferTypeAffineGrid(const Call& call, const
BlockBuilder& ctx) {
<< "AffineGrid expects the target size to be a Shape, while the given
one is "
<< call->args[1]->GetTypeKey();
}
- if (size_ty->ndim != 2) {
+ // 2D output uses a 2-dim spatial size ([H, W]); 3D uses a 3-dim size ([D,
H, W]).
+ if (size_ty->ndim != 2 && size_ty->ndim != 3) {
TVM_FFI_VISIT_THROW(ValueError, call)
- << "AffineGrid expects the target size to be a 2-dim shape, while the
given "
+ << "AffineGrid expects the target size to be a 2-dim or 3-dim shape,
while the given "
"one has ndim "
<< size_ty->ndim;
}
+ const int spatial = size_ty->ndim;
- // data should be 3-D: [batch, 2, 3]
+ // data should be 3-D: [batch, spatial, spatial + 1] (i.e. [N, 2, 3] or [N,
3, 4]).
if (data_ty->ndim != -1 && data_ty->ndim != 3) {
- TVM_FFI_VISIT_THROW(ValueError, call)
- << "AffineGrid expects the input data to be 3-D (batch, 2, 3), but got
ndim "
- << data_ty->ndim;
+ TVM_FFI_VISIT_THROW(ValueError, call) << "AffineGrid expects the input
data to be 3-D (batch, "
+ "spatial, spatial + 1), but got
ndim "
+ << data_ty->ndim;
}
const auto* data_shape = data_ty->shape.as<ShapeExprNode>();
if (data_shape != nullptr) {
- // Check that the affine matrix has shape [batch, 2, 3]
if (data_shape->values.size() >= 2) {
auto* dim1 = data_shape->values[1].as<IntImmNode>();
- if (dim1 != nullptr && dim1->value != 2) {
+ if (dim1 != nullptr && dim1->value != spatial) {
TVM_FFI_VISIT_THROW(ValueError, call)
- << "AffineGrid expects the second dimension of input to be 2, but
got " << dim1->value;
+ << "AffineGrid expects the second dimension of input to be " <<
spatial << ", but got "
+ << dim1->value;
}
}
if (data_shape->values.size() >= 3) {
auto* dim2 = data_shape->values[2].as<IntImmNode>();
- if (dim2 != nullptr && dim2->value != 3) {
+ if (dim2 != nullptr && dim2->value != spatial + 1) {
TVM_FFI_VISIT_THROW(ValueError, call)
- << "AffineGrid expects the third dimension of input to be 3, but
got " << dim2->value;
+ << "AffineGrid expects the third dimension of input to be " <<
spatial + 1
+ << ", but got " << dim2->value;
}
}
}
@@ -429,15 +432,16 @@ Type InferTypeAffineGrid(const Call& call, const
BlockBuilder& ctx) {
ffi::Optional<PrimType> out_dtype = data_ty->dtype;
if (data_shape == nullptr || size_value == nullptr) {
- return TensorType(out_dtype, /*ndim=*/4, data_ty->vdevice);
+ return TensorType(out_dtype, /*ndim=*/spatial + 2, data_ty->vdevice);
}
- // Output shape: [batch, 2, target_height, target_width]
+ // Output shape: [batch, spatial, *target_spatial_dims].
ffi::Array<PrimExpr> out_shape;
- out_shape.push_back(data_shape->values[0]); // batch
- out_shape.push_back(IntImm::Int64(2)); // 2 (spatial dimensions)
- out_shape.push_back(size_value->values[0]); // target_height
- out_shape.push_back(size_value->values[1]); // target_width
+ out_shape.push_back(data_shape->values[0]); // batch
+ out_shape.push_back(IntImm::Int64(spatial)); // number of spatial
coordinates
+ for (int i = 0; i < spatial; ++i) {
+ out_shape.push_back(size_value->values[i]); // target spatial dim
+ }
return TensorType(ShapeExpr(out_shape), out_dtype, data_ty->vdevice);
}
diff --git a/tests/python/relax/test_frontend_onnx.py
b/tests/python/relax/test_frontend_onnx.py
index 972bc48307..b05b8e3742 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -10150,6 +10150,37 @@ def test_affine_grid():
verify_affine_grid(1, ExpectedAlignCorners)
+def test_affine_grid_3d():
+ affine_grid_node = helper.make_node(
+ "AffineGrid",
+ ["theta", "size"],
+ ["grid"],
+ align_corners=1,
+ )
+
+ graph = helper.make_graph(
+ [affine_grid_node],
+ "affine_grid_3d_test",
+ inputs=[
+ helper.make_tensor_value_info("theta", TensorProto.FLOAT, [2, 3,
4]),
+ ],
+ initializer=[
+ helper.make_tensor("size", TensorProto.INT64, [5], [2, 3, 8, 16,
16]),
+ ],
+ outputs=[
+ helper.make_tensor_value_info("grid", TensorProto.FLOAT, [2, 8,
16, 16, 3]),
+ ],
+ )
+
+ model = helper.make_model(graph, producer_name="affine_grid_3d_test")
+
+ tvm_model = from_onnx(model, opset=20, keep_params_in_input=True)
+ call_ops = collect_relax_call_ops(tvm_model["main"])
+ assert "relax.image.affine_grid" in call_ops
+ assert "relax.permute_dims" in call_ops
+ assert [int(d) for d in tvm_model["main"].ret_ty.shape] == [2, 8, 16, 16,
3]
+
+
@pytest.mark.parametrize("mode", ["bilinear", "nearest", "bicubic"])
@pytest.mark.parametrize("padding_mode", ["zeros", "border", "reflection"])
@pytest.mark.parametrize("align_corners", [0, 1])
diff --git a/tests/python/relax/test_transform_legalize_ops_image.py
b/tests/python/relax/test_transform_legalize_ops_image.py
index c91c4ddb8b..3313f54f53 100644
--- a/tests/python/relax/test_transform_legalize_ops_image.py
+++ b/tests/python/relax/test_transform_legalize_ops_image.py
@@ -126,12 +126,12 @@ def test_image_affine_grid():
with T.sblock("root"):
T.reads()
T.writes()
- for n, dim, i, j in T.grid(T.int64(2), T.int64(2),
T.int64(16), T.int64(16)):
+ for n, dim, i0, i1 in T.grid(T.int64(2), T.int64(2),
T.int64(16), T.int64(16)):
with T.sblock("compute"):
- v_n, v_dim, v_i, v_j = T.axis.remap("SSSS", [n, dim,
i, j])
+ v_n, v_dim, v_i0, v_i1 = T.axis.remap("SSSS", [n, dim,
i0, i1])
T.reads(theta[v_n, v_dim, T.int64(0):T.int64(3)])
- T.writes(compute[v_n, v_dim, v_i, v_j])
- compute[v_n, v_dim, v_i, v_j] = theta[v_n, v_dim,
T.int64(0)] * (T.float32(-1.0) + T.Cast("float32", v_j) *
T.float32(0.13333332666666667)) + theta[v_n, v_dim, T.int64(1)] *
(T.float32(-1.0) + T.Cast("float32", v_i) * T.float32(0.13333332666666667)) +
theta[v_n, v_dim, T.int64(2)]
+ T.writes(compute[v_n, v_dim, v_i0, v_i1])
+ compute[v_n, v_dim, v_i0, v_i1] = theta[v_n, v_dim,
T.int64(2)] + theta[v_n, v_dim, T.int64(1)] * (T.float32(-1.0) +
T.Cast("float32", v_i0) * T.float32(0.13333332666666667)) + theta[v_n, v_dim,
T.int64(0)] * (T.float32(-1.0) + T.Cast("float32", v_i1) *
T.float32(0.13333332666666667))
# fmt: on
mod = LegalizeOps()(AffineGrid)