This is an automated email from the ASF dual-hosted git repository. guan404ming pushed a commit to branch fix/onnx-affinegrid-3d in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 8f704ef160dddafb68150c3bb332c68ea56afc39 Author: Guan-Ming (Wesley) Chiu <[email protected]> AuthorDate: Fri Jun 19 23:42:10 2026 +0800 [Relax][ONNX] Support 3D AffineGrid --- python/tvm/relax/frontend/onnx/onnx_frontend.py | 20 +++++++------ python/tvm/relax/op/image/image.py | 15 ++++++---- python/tvm/topi/image/grid_sample.py | 37 ++++++++++++++---------- src/relax/op/image/resize.cc | 38 ++++++++++++++----------- tests/python/relax/test_frontend_onnx.py | 26 +++++++++++++++++ 5 files changed, 89 insertions(+), 47 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index d550d3bc00..0bd6627a8c 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -3322,15 +3322,17 @@ class AffineGrid(OnnxOpConverter): else: raise NotImplementedError(f"Dynamic size of type {type(size)} is not supported") - # Only 2D is supported: size = [N, C, H, W] - if len(size_vals) != 4: - raise ValueError("Only 2D AffineGrid (size=[N,C,H,W]) is supported") - target_h, target_w = size_vals[2], size_vals[3] - - # Relax affine_grid outputs [N, 2, H, W] - grid = bb.emit(relax.op.image.affine_grid(theta, (target_h, target_w))) - # Permute to ONNX convention [N, H, W, 2] - return bb.emit(relax.op.permute_dims(grid, axes=[0, 2, 3, 1])) + if len(size_vals) == 4: + # 2D: size = [N, C, H, W]; relax affine_grid outputs [N, 2, H, W]. + grid = bb.emit(relax.op.image.affine_grid(theta, tuple(size_vals[2:]))) + # Permute to ONNX convention [N, H, W, 2]. + return bb.emit(relax.op.permute_dims(grid, axes=[0, 2, 3, 1])) + if len(size_vals) == 5: + # 3D: size = [N, C, D, H, W]; relax affine_grid outputs [N, 3, D, H, W]. + grid = bb.emit(relax.op.image.affine_grid(theta, tuple(size_vals[2:]))) + # Permute to ONNX convention [N, D, H, W, 3]. + return bb.emit(relax.op.permute_dims(grid, axes=[0, 2, 3, 4, 1])) + raise ValueError("AffineGrid expects size to be [N,C,H,W] (2D) or [N,C,D,H,W] (3D)") class Einsum(OnnxOpConverter): diff --git a/python/tvm/relax/op/image/image.py b/python/tvm/relax/op/image/image.py index 323bfa74b5..09837bfbb2 100644 --- a/python/tvm/relax/op/image/image.py +++ b/python/tvm/relax/op/image/image.py @@ -238,7 +238,7 @@ def affine_grid( data: Expr, size: Expr | SizeLike, ) -> Expr: - """Generate a 2D sampling grid using an affine transformation matrix. + """Generate a 2D or 3D sampling grid using an affine transformation matrix. This operation is described in https://arxiv.org/pdf/1506.02025.pdf. It generates a uniform sampling grid within the target shape, normalizes it @@ -247,16 +247,19 @@ def affine_grid( Parameters ---------- data : relax.Expr - The input affine matrix tensor with shape [batch, 2, 3]. + The input affine matrix tensor with shape [batch, 2, 3] for 2D or + [batch, 3, 4] for 3D. - size : Union[Expr, PrimExprLike, Tuple[PrimExprLike, PrimExprLike]] - The target output spatial shape (H, W). If a single integer or PrimExpr - is provided, it is interpreted as a square output shape (size, size). + size : Union[Expr, PrimExprLike, Tuple[PrimExprLike, ...]] + The target output spatial shape, (H, W) for 2D or (D, H, W) for 3D. If a + single integer or PrimExpr is provided, it is interpreted as a square 2D + output shape (size, size). Returns ------- result : relax.Expr - The output grid tensor with shape [batch, 2, H, W]. + The output grid tensor with shape [batch, 2, H, W] for 2D or + [batch, 3, D, H, W] for 3D. Note ---- diff --git a/python/tvm/topi/image/grid_sample.py b/python/tvm/topi/image/grid_sample.py index 79032f41b3..15f6178cd6 100644 --- a/python/tvm/topi/image/grid_sample.py +++ b/python/tvm/topi/image/grid_sample.py @@ -21,7 +21,7 @@ from tvm import te, tirx def affine_grid(data, target_shape): - """affine_grid operator that generates 2D sampling grid. + """affine_grid operator that generates a 2D or 3D sampling grid. This operation is described in https://arxiv.org/pdf/1506.02025.pdf. It generates a uniform sampling grid within the target shape and normalizes it to [-1, 1]. The provided affine @@ -30,31 +30,38 @@ def affine_grid(data, target_shape): Parameters ---------- data : tvm.Tensor - 3-D with shape [batch, 2, 3]. The affine matrix. + 3-D with shape [batch, 2, 3] for 2D or [batch, 3, 4] for 3D. The affine matrix. - target_shape: list/tuple of two int - Specifies the output shape (H, W). + target_shape: list/tuple of int + Specifies the output spatial shape (H, W) for 2D or (D, H, W) for 3D. Returns ------- Output : tvm.Tensor - 4-D with shape [batch, 2, target_height, target_width] + [batch, 2, H, W] for 2D or [batch, 3, D, H, W] for 3D. """ assert target_shape is not None - assert len(target_shape) == 2 - assert target_shape[0] > 1 and target_shape[1] > 1, ( - "target height/width should be greater than 1" - ) + assert len(target_shape) in (2, 3) + assert all(s > 1 for s in target_shape), "target spatial dims should be greater than 1" dtype = data.dtype - y_step = tirx.const((2.0 - 1e-7) / (target_shape[0] - 1), dtype=dtype) - x_step = tirx.const((2.0 - 1e-7) / (target_shape[1] - 1), dtype=dtype) start = tirx.const(-1.0, dtype=dtype) + steps = [tirx.const((2.0 - 1e-7) / (s - 1), dtype=dtype) for s in target_shape] + + if len(target_shape) == 2: + + def _compute(n, dim, i, j): + y = start + i * steps[0] + x = start + j * steps[1] + return data[n, dim, 0] * x + data[n, dim, 1] * y + data[n, dim, 2] + + else: - def _compute(n, dim, i, j): - y = start + i * y_step - x = start + j * x_step - return data[n, dim, 0] * x + data[n, dim, 1] * y + data[n, dim, 2] + def _compute(n, dim, k, i, j): + z = start + k * steps[0] + y = start + i * steps[1] + x = start + j * steps[2] + return data[n, dim, 0] * x + data[n, dim, 1] * y + data[n, dim, 2] * z + data[n, dim, 3] oshape = (data.shape[0], len(target_shape), *target_shape) return te.compute(oshape, _compute, tag="affine_grid") diff --git a/src/relax/op/image/resize.cc b/src/relax/op/image/resize.cc index beb89af087..86a8dc0095 100644 --- a/src/relax/op/image/resize.cc +++ b/src/relax/op/image/resize.cc @@ -385,35 +385,38 @@ Type InferTypeAffineGrid(const Call& call, const BlockBuilder& ctx) { << "AffineGrid expects the target size to be a Shape, while the given one is " << call->args[1]->GetTypeKey(); } - if (size_ty->ndim != 2) { + // 2D output uses a 2-dim spatial size ([H, W]); 3D uses a 3-dim size ([D, H, W]). + if (size_ty->ndim != 2 && size_ty->ndim != 3) { TVM_FFI_VISIT_THROW(ValueError, call) - << "AffineGrid expects the target size to be a 2-dim shape, while the given " + << "AffineGrid expects the target size to be a 2-dim or 3-dim shape, while the given " "one has ndim " << size_ty->ndim; } + const int spatial = size_ty->ndim; - // data should be 3-D: [batch, 2, 3] + // data should be 3-D: [batch, spatial, spatial + 1] (i.e. [N, 2, 3] or [N, 3, 4]). if (data_ty->ndim != -1 && data_ty->ndim != 3) { - TVM_FFI_VISIT_THROW(ValueError, call) - << "AffineGrid expects the input data to be 3-D (batch, 2, 3), but got ndim " - << data_ty->ndim; + TVM_FFI_VISIT_THROW(ValueError, call) << "AffineGrid expects the input data to be 3-D (batch, " + "spatial, spatial + 1), but got ndim " + << data_ty->ndim; } const auto* data_shape = data_ty->shape.as<ShapeExprNode>(); if (data_shape != nullptr) { - // Check that the affine matrix has shape [batch, 2, 3] if (data_shape->values.size() >= 2) { auto* dim1 = data_shape->values[1].as<IntImmNode>(); - if (dim1 != nullptr && dim1->value != 2) { + if (dim1 != nullptr && dim1->value != spatial) { TVM_FFI_VISIT_THROW(ValueError, call) - << "AffineGrid expects the second dimension of input to be 2, but got " << dim1->value; + << "AffineGrid expects the second dimension of input to be " << spatial << ", but got " + << dim1->value; } } if (data_shape->values.size() >= 3) { auto* dim2 = data_shape->values[2].as<IntImmNode>(); - if (dim2 != nullptr && dim2->value != 3) { + if (dim2 != nullptr && dim2->value != spatial + 1) { TVM_FFI_VISIT_THROW(ValueError, call) - << "AffineGrid expects the third dimension of input to be 3, but got " << dim2->value; + << "AffineGrid expects the third dimension of input to be " << spatial + 1 + << ", but got " << dim2->value; } } } @@ -421,15 +424,16 @@ Type InferTypeAffineGrid(const Call& call, const BlockBuilder& ctx) { DataType out_dtype = data_ty->dtype; if (data_shape == nullptr || size_value == nullptr) { - return TensorType(out_dtype, /*ndim=*/4, data_ty->vdevice); + return TensorType(out_dtype, /*ndim=*/spatial + 2, data_ty->vdevice); } - // Output shape: [batch, 2, target_height, target_width] + // Output shape: [batch, spatial, *target_spatial_dims]. ffi::Array<PrimExpr> out_shape; - out_shape.push_back(data_shape->values[0]); // batch - out_shape.push_back(IntImm::Int64(2)); // 2 (spatial dimensions) - out_shape.push_back(size_value->values[0]); // target_height - out_shape.push_back(size_value->values[1]); // target_width + out_shape.push_back(data_shape->values[0]); // batch + out_shape.push_back(IntImm::Int64(spatial)); // number of spatial coordinates + for (int i = 0; i < spatial; ++i) { + out_shape.push_back(size_value->values[i]); // target spatial dim + } return TensorType(ShapeExpr(out_shape), out_dtype, data_ty->vdevice); } diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index d3036a2547..b47bd1c43b 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -5581,6 +5581,32 @@ def test_affine_grid(): check_correctness(model, opset=20) +def test_affine_grid_3d(): + affine_grid_node = helper.make_node( + "AffineGrid", + ["theta", "size"], + ["grid"], + align_corners=1, + ) + + graph = helper.make_graph( + [affine_grid_node], + "affine_grid_3d_test", + inputs=[ + helper.make_tensor_value_info("theta", TensorProto.FLOAT, [2, 3, 4]), + ], + initializer=[ + helper.make_tensor("size", TensorProto.INT64, [5], [2, 3, 8, 16, 16]), + ], + outputs=[ + helper.make_tensor_value_info("grid", TensorProto.FLOAT, [2, 8, 16, 16, 3]), + ], + ) + + model = helper.make_model(graph, producer_name="affine_grid_3d_test") + check_correctness(model, opset=20) + + @pytest.mark.parametrize("mode", ["bilinear", "nearest", "bicubic"]) @pytest.mark.parametrize("padding_mode", ["zeros", "border", "reflection"]) @pytest.mark.parametrize("align_corners", [0, 1])
