This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 2f2469e637 [Relax] Add affine_grid operator with PyTorch and ONNX
frontend support (#18933)
2f2469e637 is described below
commit 2f2469e6371dd4c9f89cba3924d877d090230861
Author: HoYi <[email protected]>
AuthorDate: Fri Mar 27 10:38:10 2026 +0800
[Relax] Add affine_grid operator with PyTorch and ONNX frontend support
(#18933)
## Summary
Add `relax.image.affine_grid` operator for Spatial Transformer Networks,
along with PyTorch and ONNX frontend integration.
TOPI compute (`topi.image.affine_grid`) already exists. This PR
completes the Relax-level registration and frontend support, following
the existing `resize2d` / `grid_sample` pattern.
## Changes
**Relax op registration:**
- C++ op function, FFI registration, and struct info inference
(`resize.h`, `resize.cc`)
- Python wrapper with flexible size parameter handling (`image.py`)
- Legalization to `topi.image.affine_grid` with `PrimExpr` → `int`
conversion
- Op-level tests (struct info inference + e2e numerical correctness) and
legalization test
**PyTorch frontend:**
- Converter for `aten.affine_grid_generator.default`
- Layout conversion from TVM `[N,2,H,W]` to PyTorch `[N,H,W,2]` via
`permute_dims`
- Single-kernel path is 5.6x faster than the decomposed path (30+ ops)
- Structural IR test and numerical correctness test
**ONNX frontend:**
- `AffineGrid` converter with `_impl_v20` (opset 20, when the op was
first introduced)
- Support for constant size tensor `[N,C,H,W]`
- Layout conversion from TVM `[N,2,H,W]` to ONNX `[N,H,W,2]`
- End-to-end correctness test against ONNX Runtime
## Limitations
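- Only `align_corners=True` is supported (matches current TOPI
implementation); the PyTorch and ONNX converters raise
`NotImplementedError` otherwise
- Only 2D affine grid is supported
- The target size must be static: legalization rejects symbolic `size`
values, and the ONNX converter requires a constant `size` input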
## Validation
```bash
pytest tests/python/relax/test_op_image.py -k "affine_grid" -v  # 8 passed
pytest tests/python/relax/test_transform_legalize_ops_image.py -k "affine_grid" -v  # 1 passed
pytest tests/python/relax/test_frontend_from_exported_program.py -k "affine_grid" -v  # 2 passed
pytest tests/python/relax/test_frontend_onnx.py -k "affine_grid" -v  # 1 passed
```
All 12 tests passed.
---------
Co-authored-by: Claude Opus 4.6 <[email protected]>
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 34 ++++++
.../frontend/torch/exported_program_translator.py | 24 ++++
python/tvm/relax/op/image/__init__.py | 2 +-
python/tvm/relax/op/image/image.py | 41 +++++++
python/tvm/relax/transform/legalize_ops/image.py | 18 ++-
src/relax/op/image/resize.cc | 94 +++++++++++++++
src/relax/op/image/resize.h | 3 +
.../relax/test_frontend_from_exported_program.py | 57 +++++++++
tests/python/relax/test_frontend_onnx.py | 26 ++++
tests/python/relax/test_op_image.py | 131 +++++++++++++++++++++
.../relax/test_transform_legalize_ops_image.py | 36 ++++++
11 files changed, 464 insertions(+), 2 deletions(-)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index c8d4c469fc..a117317125 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -2411,6 +2411,39 @@ class Resize(OnnxOpConverter):
)
+class AffineGrid(OnnxOpConverter):
+ """Converts an onnx AffineGrid node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v20(cls, bb, inputs, attr, params):
+ theta = inputs[0] # [N, 2, 3] for 2D
+ size = get_constant(inputs[1], params) # [N, C, H, W] for 2D
+ align_corners = attr.get("align_corners", 0)
+
+ if align_corners != 1:
+ raise NotImplementedError(
+ "AffineGrid with align_corners=0 is not yet supported in TVM"
+ )
+
+ # Extract size values
+ if isinstance(size, relax.Constant):
+ size_vals = size.data.numpy().astype("int64").tolist()
+ elif isinstance(size, relax.expr.ShapeExpr):
+ size_vals = [int(v.value) for v in size.values]
+ else:
+ raise NotImplementedError(f"Dynamic size of type {type(size)} is
not supported")
+
+ # Only 2D is supported: size = [N, C, H, W]
+ if len(size_vals) != 4:
+ raise ValueError("Only 2D AffineGrid (size=[N,C,H,W]) is
supported")
+ target_h, target_w = size_vals[2], size_vals[3]
+
+ # Relax affine_grid outputs [N, 2, H, W]
+ grid = bb.emit(relax.op.image.affine_grid(theta, (target_h, target_w)))
+ # Permute to ONNX convention [N, H, W, 2]
+ return bb.emit(relax.op.permute_dims(grid, axes=[0, 2, 3, 1]))
+
+
class Einsum(OnnxOpConverter):
"""Converts an onnx Einsum node into an equivalent Relax expression."""
@@ -4151,6 +4184,7 @@ def _get_convert_map():
"NonMaxSuppression": NonMaxSuppression,
"AllClassNMS": AllClassNMS,
"GridSample": GridSample,
+ "AffineGrid": AffineGrid,
"Upsample": Upsample,
# others
"DepthToSpace": DepthToSpace,
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 47633c69b5..cc37554bf3 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -1123,6 +1123,29 @@ class ExportedProgramImporter(BaseFXGraphImporter):
)
)
+ def _affine_grid_generator(self, node: fx.Node) -> relax.Var:
+ """Convert torch.nn.functional.affine_grid to
relax.op.image.affine_grid."""
+ args = self.retrieve_args(node)
+ theta = args[0] # [N, 2, 3]
+ size = args[1] # [N, C, H, W]
+ align_corners = args[2] if len(args) > 2 else False
+
+ if not align_corners:
+ raise NotImplementedError(
+ "affine_grid with align_corners=False is not yet supported in
TVM"
+ )
+
+ # Extract spatial dimensions (H, W) from PyTorch's [N, C, H, W] size
+ target_h = size[2]
+ target_w = size[3]
+
+ # Relax affine_grid outputs [N, 2, H, W]
+ grid = self.block_builder.emit(
+ relax.op.image.affine_grid(theta, (target_h, target_w))
+ )
+ # Permute to PyTorch convention [N, H, W, 2]
+ return self.block_builder.emit(relax.op.permute_dims(grid, axes=[0, 2,
3, 1]))
+
def _torchvision_roi_align(self, node: fx.Node) -> relax.Var:
"""Convert torchvision.ops.roi_align to relax.op.vision.roi_align."""
args = self.retrieve_args(node)
@@ -1768,6 +1791,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"zeros.default": self._zeros,
"zeros_like.default": self._zeros_like,
"grid_sampler_2d.default": self._grid_sampler_2d,
+ "affine_grid_generator.default": self._affine_grid_generator,
"roi_align.default": self._torchvision_roi_align,
# datatype
"to.dtype": self._to,
diff --git a/python/tvm/relax/op/image/__init__.py
b/python/tvm/relax/op/image/__init__.py
index 6b02c32199..dcc0d1f883 100644
--- a/python/tvm/relax/op/image/__init__.py
+++ b/python/tvm/relax/op/image/__init__.py
@@ -17,4 +17,4 @@
# under the License.
"""Image operators."""
-from .image import grid_sample, resize2d, resize3d
+from .image import affine_grid, grid_sample, resize2d, resize3d
diff --git a/python/tvm/relax/op/image/image.py
b/python/tvm/relax/op/image/image.py
index b267f40709..323bfa74b5 100644
--- a/python/tvm/relax/op/image/image.py
+++ b/python/tvm/relax/op/image/image.py
@@ -16,6 +16,8 @@
# under the License.
"""Image operators."""
+from typing import cast
+
from tvm import DataType
from tvm.ir.expr import PrimExpr
@@ -23,6 +25,7 @@ from ...expr import Expr, ShapeExpr
from . import _ffi_api
PrimExprLike = int | PrimExpr
+SizeLike = PrimExprLike | tuple[PrimExprLike, ...]
def resize2d(
@@ -229,3 +232,41 @@ def grid_sample(
padding_mode,
align_corners,
)
+
+
+def affine_grid(
+ data: Expr,
+ size: Expr | SizeLike,
+) -> Expr:
+ """Generate a 2D sampling grid using an affine transformation matrix.
+
+ This operation is described in https://arxiv.org/pdf/1506.02025.pdf.
+ It generates a uniform sampling grid within the target shape, normalizes it
+ to [-1, 1], and applies the provided affine transformation.
+
+ Parameters
+ ----------
+ data : relax.Expr
+ The input affine matrix tensor with shape [batch, 2, 3].
+
+ size : Union[Expr, PrimExprLike, Tuple[PrimExprLike, PrimExprLike]]
+ The target output spatial shape (H, W). If a single integer or PrimExpr
+ is provided, it is interpreted as a square output shape (size, size).
+
+ Returns
+ -------
+ result : relax.Expr
+ The output grid tensor with shape [batch, 2, H, W].
+
+ Note
+ ----
+ Only `align_corners=True` is supported by this operator, matching the
+ behavior of the underlying TOPI implementation. When using this operator
+ via PyTorch or ONNX frontends, `align_corners=False` will be rejected.
+ """
+ if isinstance(size, int | PrimExpr):
+ size = (size, size)
+ if isinstance(size, tuple | list):
+ size = ShapeExpr(size)
+
+ return cast(Expr, _ffi_api.affine_grid(data, size))
diff --git a/python/tvm/relax/transform/legalize_ops/image.py
b/python/tvm/relax/transform/legalize_ops/image.py
index 1e7aaebceb..19431a2731 100644
--- a/python/tvm/relax/transform/legalize_ops/image.py
+++ b/python/tvm/relax/transform/legalize_ops/image.py
@@ -17,7 +17,7 @@
# pylint: disable=invalid-name
"""Default legalization function for image operators."""
-from tvm import topi
+from tvm import tirx, topi
from ...block_builder import BlockBuilder
from ...expr import Call, Expr
@@ -54,6 +54,22 @@ def _image_grid_sample(bb: BlockBuilder, call: Call) -> Expr:
)
+@register_legalize("relax.image.affine_grid")
+def _image_affine_grid(bb: BlockBuilder, call: Call) -> Expr:
+ for v in call.args[1].values:
+ if not isinstance(v, (int, tirx.IntImm)):
+ raise ValueError(
+ "affine_grid legalization requires static target_shape, "
+ f"got symbolic value: {v}"
+ )
+ target_shape = [int(v) for v in call.args[1].values]
+ return bb.call_te(
+ topi.image.affine_grid,
+ call.args[0],
+ target_shape=target_shape,
+ )
+
+
@register_legalize("relax.image.resize3d")
def _image_resize3d(bb: BlockBuilder, call: Call) -> Expr:
return bb.call_te(
diff --git a/src/relax/op/image/resize.cc b/src/relax/op/image/resize.cc
index fe9df47dd5..ba7de8115e 100644
--- a/src/relax/op/image/resize.cc
+++ b/src/relax/op/image/resize.cc
@@ -340,5 +340,99 @@ TVM_REGISTER_OP("relax.image.grid_sample")
.set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kFollow)
.set_attr<Bool>("FPurity", Bool(true));
+/* relax.image.affine_grid */
+
+Expr affine_grid(Expr data, Expr size) {
+ static const Op& op = Op::Get("relax.image.affine_grid");
+ return Call(op, {std::move(data), std::move(size)}, Attrs(), {});
+}
+
+TVM_FFI_STATIC_INIT_BLOCK() {
+ namespace refl = tvm::ffi::reflection;
+ refl::GlobalDef().def("relax.op.image.affine_grid", affine_grid);
+}
+
+StructInfo InferStructInfoAffineGrid(const Call& call, const BlockBuilder&
ctx) {
+ if (call->args.size() != 2) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "AffineGrid expects two arguments, while the given
number of arguments is "
+ << call->args.size());
+ }
+
+ const auto* data_sinfo =
GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+ const auto* size_sinfo = GetStructInfoAs<ShapeStructInfoNode>(call->args[1]);
+ const auto* size_value = call->args[1].as<ShapeExprNode>();
+
+ if (data_sinfo == nullptr) {
+ ctx->ReportFatal(
+ Diagnostic::Error(call)
+ << "AffineGrid expects the input data to be a Tensor, while the given
data is "
+ << call->args[0]->GetTypeKey());
+ }
+ if (size_sinfo == nullptr) {
+ ctx->ReportFatal(
+ Diagnostic::Error(call)
+ << "AffineGrid expects the target size to be a Shape, while the given
one is "
+ << call->args[1]->GetTypeKey());
+ }
+ if (size_sinfo->ndim != 2) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "AffineGrid expects the target size to be a 2-dim
shape, while the given "
+ "one has ndim "
+ << size_sinfo->ndim);
+ }
+
+ // data should be 3-D: [batch, 2, 3]
+ if (data_sinfo->ndim != -1 && data_sinfo->ndim != 3) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "AffineGrid expects the input data to be 3-D (batch,
2, 3), but got ndim "
+ << data_sinfo->ndim);
+ }
+
+ const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
+ if (data_shape != nullptr) {
+ // Check that the affine matrix has shape [batch, 2, 3]
+ if (data_shape->values.size() >= 2) {
+ auto* dim1 = data_shape->values[1].as<IntImmNode>();
+ if (dim1 != nullptr && dim1->value != 2) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "AffineGrid expects the second dimension of input
to be 2, but got "
+ << dim1->value);
+ }
+ }
+ if (data_shape->values.size() >= 3) {
+ auto* dim2 = data_shape->values[2].as<IntImmNode>();
+ if (dim2 != nullptr && dim2->value != 3) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "AffineGrid expects the third dimension of input
to be 3, but got "
+ << dim2->value);
+ }
+ }
+ }
+
+ DataType out_dtype = data_sinfo->dtype;
+
+ if (data_shape == nullptr || size_value == nullptr) {
+ return TensorStructInfo(out_dtype, /*ndim=*/4, data_sinfo->vdevice);
+ }
+
+ // Output shape: [batch, 2, target_height, target_width]
+ ffi::Array<PrimExpr> out_shape;
+ out_shape.push_back(data_shape->values[0]); // batch
+ out_shape.push_back(IntImm(DataType::Int(64), 2)); // 2 (spatial dimensions)
+ out_shape.push_back(size_value->values[0]); // target_height
+ out_shape.push_back(size_value->values[1]); // target_width
+
+ return TensorStructInfo(ShapeExpr(out_shape), out_dtype,
data_sinfo->vdevice);
+}
+
+TVM_REGISTER_OP("relax.image.affine_grid")
+ .set_num_inputs(2)
+ .add_argument("data", "Tensor", "The input affine matrix tensor.")
+ .add_argument("size", "Shape", "The target output shape (H, W).")
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoAffineGrid)
+ .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kFollow)
+ .set_attr<Bool>("FPurity", Bool(true));
+
} // namespace relax
} // namespace tvm
diff --git a/src/relax/op/image/resize.h b/src/relax/op/image/resize.h
index c769cf91f5..06a927d3a7 100644
--- a/src/relax/op/image/resize.h
+++ b/src/relax/op/image/resize.h
@@ -48,6 +48,9 @@ Expr resize3d(Expr data, Expr size, ffi::Array<FloatImm> roi,
ffi::String layout
Expr grid_sample(Expr data, Expr grid, ffi::String method, ffi::String layout,
ffi::String padding_mode, bool align_corners);
+/*! \brief Image affine_grid operator. */
+Expr affine_grid(Expr data, Expr size);
+
} // namespace relax
} // namespace tvm
diff --git a/tests/python/relax/test_frontend_from_exported_program.py
b/tests/python/relax/test_frontend_from_exported_program.py
index 7a3548b4cf..6029499372 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -9095,5 +9095,62 @@ def test_cond_nested():
)
+def test_affine_grid():
+ class AffineGrid(Module):
+ def forward(self, theta):
+ return torch.nn.functional.affine_grid(
+ theta, [1, 3, 16, 16], align_corners=True
+ )
+
+ @tvm.script.ir_module
+ class expected:
+ @R.function
+ def main(
+ theta: R.Tensor((1, 2, 3), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((1, 16, 16, 2), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((1, 2, 16, 16), dtype="float32") =
R.image.affine_grid(
+ theta, size=(16, 16)
+ )
+ lv1: R.Tensor((1, 16, 16, 2), dtype="float32") =
R.permute_dims(
+ lv, axes=[0, 2, 3, 1]
+ )
+ gv: R.Tuple(R.Tensor((1, 16, 16, 2), dtype="float32")) = (lv1,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(1, 2, 3, dtype=torch.float32),)
+ # Disable decomposition to keep aten.affine_grid_generator as a single op
+ verify_model(AffineGrid(), example_args, {}, expected,
run_ep_decomposition=False)
+
+
+def test_affine_grid_numerically():
+ """Verify affine_grid numerical correctness: PyTorch vs TVM via our
converter."""
+
+ class AffineGrid(Module):
+ def forward(self, theta):
+ return torch.nn.functional.affine_grid(
+ theta, [2, 3, 8, 12], align_corners=True
+ )
+
+ model = AffineGrid()
+ example_args = (torch.randn(2, 2, 3, dtype=torch.float32),)
+
+ with torch.no_grad():
+ pytorch_output = model(*example_args)
+
+ exported_program = export(model, args=example_args)
+ mod = from_exported_program(exported_program, run_ep_decomposition=False)
+
+ exe = tvm.compile(mod, target="llvm")
+ vm = relax.VirtualMachine(exe, tvm.cpu())
+
+ tvm_args = [tvm.runtime.tensor(arg.numpy()) for arg in example_args]
+ tvm_output = vm["main"](*tvm_args)
+ tvm_output_np = tvm_output[0].numpy()
+
+ tvm.testing.assert_allclose(tvm_output_np, pytorch_output.numpy(),
rtol=1e-5, atol=1e-5)
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_frontend_onnx.py
b/tests/python/relax/test_frontend_onnx.py
index 86fa533874..887533f261 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -4011,6 +4011,32 @@ def test_nms_score_threshold():
)
+def test_affine_grid():
+ affine_grid_node = helper.make_node(
+ "AffineGrid",
+ ["theta", "size"],
+ ["grid"],
+ align_corners=1,
+ )
+
+ graph = helper.make_graph(
+ [affine_grid_node],
+ "affine_grid_test",
+ inputs=[
+ helper.make_tensor_value_info("theta", TensorProto.FLOAT, [2, 2,
3]),
+ ],
+ initializer=[
+ helper.make_tensor("size", TensorProto.INT64, [4], [2, 3, 16, 16]),
+ ],
+ outputs=[
+ helper.make_tensor_value_info("grid", TensorProto.FLOAT, [2, 16,
16, 2]),
+ ],
+ )
+
+ model = helper.make_model(graph, producer_name="affine_grid_test")
+ check_correctness(model, opset=20)
+
+
@pytest.mark.parametrize("mode", ["bilinear", "nearest", "bicubic"])
@pytest.mark.parametrize("padding_mode", ["zeros", "border", "reflection"])
@pytest.mark.parametrize("align_corners", [0, 1])
diff --git a/tests/python/relax/test_op_image.py
b/tests/python/relax/test_op_image.py
index 6650fc359b..3009b9414a 100644
--- a/tests/python/relax/test_op_image.py
+++ b/tests/python/relax/test_op_image.py
@@ -14,10 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import numpy as np
import pytest
import tvm
import tvm.testing
+import tvm.topi.testing
from tvm import TVMError, relax, tirx
from tvm.ir import Op, VDevice
from tvm.script import relax as R
@@ -26,6 +28,8 @@ from tvm.script import relax as R
def test_op_correctness():
x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
assert relax.op.image.resize2d(x, (28, 28)).op ==
Op.get("relax.image.resize2d")
+ theta = relax.Var("theta", R.Tensor((2, 2, 3), "float32"))
+ assert relax.op.image.affine_grid(theta, (16, 16)).op ==
Op.get("relax.image.affine_grid")
y = relax.Var("y", R.Tensor((2, 3, 8, 16, 32), "float32"))
assert relax.op.image.resize3d(y, (4, 8, 12)).op ==
Op.get("relax.image.resize3d")
@@ -356,5 +360,132 @@ def test_resize2d_infer_struct_info_wrong_input_type():
bb.normalize(relax.op.image.resize2d(x2, s0))
+def test_affine_grid_infer_struct_info():
+ bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
+ x0 = relax.Var("x", R.Tensor((2, 2, 3), "float32"))
+ x1 = relax.Var("x", R.Tensor((2, 2, 3), "float32", vdev0))
+ x2 = relax.Var("x", R.Tensor("float32", ndim=3))
+ x3 = relax.Var("x", R.Tensor("float32"))
+ x4 = relax.Var("x", R.Tensor(ndim=3))
+
+ _check_inference(
+ bb,
+ relax.op.image.affine_grid(x0, (16, 16)),
+ relax.TensorStructInfo((2, 2, 16, 16), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.image.affine_grid(x1, (16, 16)),
+ relax.TensorStructInfo((2, 2, 16, 16), "float32", vdev0),
+ )
+ _check_inference(
+ bb,
+ relax.op.image.affine_grid(x0, size=16),
+ relax.TensorStructInfo((2, 2, 16, 16), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.image.affine_grid(x0, size=(16, 20)),
+ relax.TensorStructInfo((2, 2, 16, 20), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.image.affine_grid(x2, size=(16, 16)),
+ relax.TensorStructInfo(dtype="float32", ndim=4),
+ )
+ _check_inference(
+ bb,
+ relax.op.image.affine_grid(x3, size=(16, 16)),
+ relax.TensorStructInfo(dtype="float32", ndim=4),
+ )
+ _check_inference(
+ bb,
+ relax.op.image.affine_grid(x4, size=(16, 16)),
+ relax.TensorStructInfo(dtype="", ndim=4),
+ )
+
+
+def test_affine_grid_infer_struct_info_shape_symbolic():
+ bb = relax.BlockBuilder()
+ n = tirx.Var("n", "int64")
+ oh = tirx.Var("oh", "int64")
+ ow = tirx.Var("ow", "int64")
+ x0 = relax.Var("x", R.Tensor((n, 2, 3), "float32"))
+
+ _check_inference(
+ bb,
+ relax.op.image.affine_grid(x0, size=(oh, ow)),
+ relax.TensorStructInfo((n, 2, oh, ow), "float32"),
+ )
+
+
+def test_affine_grid_infer_struct_info_wrong_input_type():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", relax.ShapeStructInfo((2, 2, 3)))
+ x1 = relax.Var("x", R.Tensor((2, 2, 3), "float32"))
+ s0 = relax.Var("s", R.Tensor((3, 3)))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.image.affine_grid(x0, size=(16, 16)))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.image.affine_grid(x1, s0))
+
+
+def test_affine_grid_wrong_input_ndim():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
+ x1 = relax.Var("x", R.Tensor("float32", ndim=4))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.image.affine_grid(x0, size=(16, 16)))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.image.affine_grid(x1, size=(16, 16)))
+
+
+def test_affine_grid_wrong_size_ndim():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 2, 3), "float32"))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.image.affine_grid(x0, (16, 16, 16)))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.image.affine_grid(x0, (16,)))
+
+
[email protected](
+ "batch, target_h, target_w",
+ [
+ (1, 16, 16),
+ (2, 8, 12),
+ (4, 32, 32),
+ ],
+)
+def test_affine_grid_e2e(batch, target_h, target_w):
+ """End-to-end numerical correctness test: build, run, compare with numpy
reference."""
+
+ @tvm.script.ir_module
+ class AffineGridModule:
+ @R.function
+ def main(theta: R.Tensor(("batch", 2, 3), "float32")) ->
R.Tensor("float32", ndim=4):
+ gv = R.image.affine_grid(theta, size=(target_h, target_w))
+ return gv
+
+ target = "llvm"
+ dev = tvm.cpu()
+ exe = tvm.compile(AffineGridModule, target=target)
+ vm = relax.VirtualMachine(exe, dev)
+
+ theta_np = np.random.uniform(-1, 1, size=(batch, 2, 3)).astype("float32")
+ theta_nd = tvm.runtime.tensor(theta_np, dev)
+
+ out_nd = vm["main"](theta_nd)
+ out_np = out_nd.numpy()
+
+ ref_np = tvm.topi.testing.affine_grid_python(theta_np, (target_h,
target_w))
+
+ tvm.testing.assert_allclose(out_np, ref_np, rtol=1e-5, atol=1e-5)
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_transform_legalize_ops_image.py
b/tests/python/relax/test_transform_legalize_ops_image.py
index 48166d24c4..5c80ce0375 100644
--- a/tests/python/relax/test_transform_legalize_ops_image.py
+++ b/tests/python/relax/test_transform_legalize_ops_image.py
@@ -102,6 +102,42 @@ def test_image_resize2d_symbolic():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_image_affine_grid():
+ # fmt: off
+ @tvm.script.ir_module
+ class AffineGrid:
+ @R.function
+ def main(theta: R.Tensor((2, 2, 3), "float32")) -> R.Tensor((2, 2, 16,
16), "float32"):
+ gv: R.Tensor((2, 2, 16, 16), "float32") =
R.image.affine_grid(theta, size=(16, 16))
+ return gv
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(theta: R.Tensor((2, 2, 3), "float32")) -> R.Tensor((2, 2, 16,
16), "float32"):
+ gv = R.call_tir(Expected.affine_grid, (theta,), R.Tensor((2, 2,
16, 16), dtype="float32"))
+ return gv
+
+ @T.prim_func(private=True)
+ def affine_grid(var_theta: T.handle, var_compute: T.handle):
+ T.func_attr({"tirx.noalias": True})
+ theta = T.match_buffer(var_theta, (T.int64(2), T.int64(2),
T.int64(3)))
+ compute = T.match_buffer(var_compute, (T.int64(2), T.int64(2),
T.int64(16), T.int64(16)))
+ with T.sblock("root"):
+ T.reads()
+ T.writes()
+ for n, dim, i, j in T.grid(T.int64(2), T.int64(2),
T.int64(16), T.int64(16)):
+ with T.sblock("compute"):
+ v_n, v_dim, v_i, v_j = T.axis.remap("SSSS", [n, dim,
i, j])
+ T.reads(theta[v_n, v_dim, T.int64(0):T.int64(3)])
+ T.writes(compute[v_n, v_dim, v_i, v_j])
+ compute[v_n, v_dim, v_i, v_j] = theta[v_n, v_dim,
T.int64(0)] * (T.float32(-1.0) + T.Cast("float32", v_j) *
T.float32(0.13333332666666667)) + theta[v_n, v_dim, T.int64(1)] *
(T.float32(-1.0) + T.Cast("float32", v_i) * T.float32(0.13333332666666667)) +
theta[v_n, v_dim, T.int64(2)]
+ # fmt: on
+
+ mod = LegalizeOps()(AffineGrid)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
def test_image_resize3d():
# fmt: off
@tvm.script.ir_module