This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 41a606c726 [Relax][PyTorch] Add support for grid_sample operator (#18483)
41a606c726 is described below
commit 41a606c726dbbd77a3f7c7daaa1d069f705477bd
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Sat Nov 22 01:18:10 2025 +0800
[Relax][PyTorch] Add support for grid_sample operator (#18483)
## Related Issue
closes #18475
## How
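- Add the `relax.image.grid_sample` operator (attributes, struct-info inference, and legalization to `topi.image.grid_sample`) and map PyTorch's `grid_sampler_2d` to it in the ExportedProgram frontend.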
---
include/tvm/relax/attrs/image.h | 22 ++++++
.../frontend/torch/exported_program_translator.py | 27 ++++++++
python/tvm/relax/op/image/__init__.py | 2 +-
python/tvm/relax/op/image/image.py | 49 ++++++++++++++
python/tvm/relax/transform/legalize_ops/image.py | 13 ++++
src/relax/op/image/resize.cc | 78 ++++++++++++++++++++++
src/relax/op/image/resize.h | 4 ++
.../relax/test_frontend_from_exported_program.py | 34 ++++++++++
8 files changed, 228 insertions(+), 1 deletion(-)
diff --git a/include/tvm/relax/attrs/image.h b/include/tvm/relax/attrs/image.h
index 4d626a022c..b367ce5843 100644
--- a/include/tvm/relax/attrs/image.h
+++ b/include/tvm/relax/attrs/image.h
@@ -78,6 +78,28 @@ struct Resize2DAttrs : public
AttrsNodeReflAdapter<Resize2DAttrs> {
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Resize2DAttrs",
Resize2DAttrs, BaseAttrsNode);
}; // struct Resize2dAttrs
+/*! \brief Attributes used in image grid_sample operator */
+struct GridSampleAttrs : public AttrsNodeReflAdapter<GridSampleAttrs> {
+  /*! \brief Interpolation method: "nearest", "bilinear" or "bicubic". */
+  ffi::String method;
+  /*! \brief Dimension ordering of the input data (only "NCHW" can currently be lowered). */
+  ffi::String layout;
+  /*! \brief Handling of grid values outside the input: "zeros", "border" or "reflection". */
+  ffi::String padding_mode;
+  /*! \brief If true, the corner pixels of the input and output tensors are aligned. */
+  bool align_corners;
+
+  static void RegisterReflection() {
+    namespace refl = tvm::ffi::reflection;
+    refl::ObjectDef<GridSampleAttrs>()
+        .def_ro("method", &GridSampleAttrs::method,
+                "Interpolation method. Can be 'nearest', 'bilinear', or 'bicubic'.")
+        .def_ro("layout", &GridSampleAttrs::layout,
+                "Dimension ordering of input data. Currently only 'NCHW' is supported "
+                "when lowering to topi.image.grid_sample.")
+        .def_ro("padding_mode", &GridSampleAttrs::padding_mode,
+                "Padding mode for outside grid values. Can be 'zeros', 'border', or 'reflection'.")
+        .def_ro("align_corners", &GridSampleAttrs::align_corners,
+                "If True, the corner pixels of the input and output tensors are aligned.");
+  }
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.GridSampleAttrs", GridSampleAttrs,
+                                    BaseAttrsNode);
+};  // struct GridSampleAttrs
+
} // namespace relax
} // namespace tvm
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index e91f006926..64af72c457 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -848,6 +848,32 @@ class ExportedProgramImporter(BaseFXGraphImporter):
)
return self.block_builder.emit(relax.op.zeros(size, dtype))
+ def _grid_sampler_2d(self, node: fx.Node) -> relax.Var:
+ """Convert torch.nn.functional.grid_sample to
relax.op.image.grid_sample."""
+ args = self.retrieve_args(node)
+ data = args[0]
+ grid = args[1]
+ interp_mode = args[2] if len(args) > 2 else 0
+ pad_mode = args[3] if len(args) > 3 else 0
+ align_corners = args[4] if len(args) > 4 else False
+
+ interp_map = {0: "bilinear", 1: "nearest", 2: "bicubic"}
+ pad_map = {0: "zeros", 1: "border", 2: "reflection"}
+
+ method = interp_map.get(interp_mode, "bilinear")
+ padding_mode = pad_map.get(pad_mode, "zeros")
+
+ return self.block_builder.emit(
+ relax.op.image.grid_sample(
+ data,
+ grid,
+ method=method,
+ layout="NCHW",
+ padding_mode=padding_mode,
+ align_corners=align_corners,
+ )
+ )
+
def _scalar_tensor(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
scalar_value = args[0]
@@ -1222,6 +1248,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"zero_.default": self._zeros_inplace,
"zeros.default": self._zeros,
"zeros_like.default": self._zeros_like,
+ "grid_sampler_2d.default": self._grid_sampler_2d,
# datatype
"to.dtype": self._to,
"to.dtype_layout": self._to,
diff --git a/python/tvm/relax/op/image/__init__.py
b/python/tvm/relax/op/image/__init__.py
index 10ef635cbf..15c1847b28 100644
--- a/python/tvm/relax/op/image/__init__.py
+++ b/python/tvm/relax/op/image/__init__.py
@@ -15,4 +15,4 @@
# specific language governing permissions and limitations
# under the License.
"""Image operators."""
-from .image import resize2d
+from .image import grid_sample, resize2d
diff --git a/python/tvm/relax/op/image/image.py
b/python/tvm/relax/op/image/image.py
index afadbf35fb..893f7af90f 100644
--- a/python/tvm/relax/op/image/image.py
+++ b/python/tvm/relax/op/image/image.py
@@ -130,3 +130,52 @@ def resize2d(
extrapolation_value,
out_dtype,
)
+
+
+def grid_sample(
+ data: Expr,
+ grid: Expr,
+ method: str = "bilinear",
+ layout: str = "NCHW",
+ padding_mode: str = "zeros",
+ align_corners: bool = False,
+) -> Expr:
+ """Applies grid sampling to input feature map.
+
+ Given data and grid, the output is computed by sampling from data using
+ the grid coordinates.
+
+ Parameters
+ ----------
+ data : relax.Expr
+ The input data tensor with shape [N, C, H, W] for NCHW layout.
+
+ grid : relax.Expr
+ The grid tensor with shape [N, H_out, W_out, 2]. The values are
normalized
+ to [-1, 1], where (-1, -1) is the top-left corner and (1, 1) is the
bottom-right.
+
+ method : str
+ Interpolation method. Can be 'nearest', 'bilinear', or 'bicubic'.
+
+ layout : str
+ Layout of the input data. Default is 'NCHW'.
+
+ padding_mode : str
+ Padding mode for outside grid values. Can be 'zeros', 'border', or
'reflection'.
+
+ align_corners : bool
+ If True, the corner pixels of the input and output tensors are aligned.
+
+ Returns
+ -------
+ result : relax.Expr
+ The sampled output tensor with shape [N, C, H_out, W_out].
+ """
+ return _ffi_api.grid_sample( # type: ignore
+ data,
+ grid,
+ method,
+ layout,
+ padding_mode,
+ align_corners,
+ )
diff --git a/python/tvm/relax/transform/legalize_ops/image.py
b/python/tvm/relax/transform/legalize_ops/image.py
index 1b2a342b0b..7a1c2e92cb 100644
--- a/python/tvm/relax/transform/legalize_ops/image.py
+++ b/python/tvm/relax/transform/legalize_ops/image.py
@@ -37,3 +37,16 @@ def _image_resize2d(bb: BlockBuilder, call: Call) -> Expr:
bicubic_exclude=call.attrs.cubic_exclude,
extrapolation_value=call.attrs.extrapolation_value,
)
+
+
+@register_legalize("relax.image.grid_sample")
+def _image_grid_sample(bb: BlockBuilder, call: Call) -> Expr:
+ return bb.call_te(
+ topi.image.grid_sample,
+ call.args[0],
+ call.args[1],
+ method=call.attrs.method,
+ layout=call.attrs.layout,
+ padding_mode=call.attrs.padding_mode,
+ align_corners=call.attrs.align_corners,
+ )
diff --git a/src/relax/op/image/resize.cc b/src/relax/op/image/resize.cc
index 8b7b8dd2a5..59d845d867 100644
--- a/src/relax/op/image/resize.cc
+++ b/src/relax/op/image/resize.cc
@@ -148,5 +148,83 @@ TVM_REGISTER_OP("relax.image.resize2d")
.set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kFollow)
.set_attr<Bool>("FPurity", Bool(true));
+/* relax.image.grid_sample */
+
+TVM_FFI_STATIC_INIT_BLOCK() { GridSampleAttrs::RegisterReflection(); }
+
+/*!
+ * \brief Build a call to relax.image.grid_sample with the given sampling options
+ *        packed into GridSampleAttrs.
+ */
+Expr grid_sample(Expr data, Expr grid, ffi::String method, ffi::String layout,
+                 ffi::String padding_mode, bool align_corners) {
+  static const Op& grid_sample_op = Op::Get("relax.image.grid_sample");
+
+  auto sample_attrs = ffi::make_object<GridSampleAttrs>();
+  sample_attrs->method = std::move(method);
+  sample_attrs->layout = std::move(layout);
+  sample_attrs->padding_mode = std::move(padding_mode);
+  sample_attrs->align_corners = align_corners;
+
+  ffi::Array<Expr> call_args{std::move(data), std::move(grid)};
+  return Call(grid_sample_op, std::move(call_args), Attrs(sample_attrs), {});
+}
+
+// Expose the builder to Python as relax.op.image.grid_sample.
+TVM_FFI_STATIC_INIT_BLOCK() {
+  tvm::ffi::reflection::GlobalDef().def("relax.op.image.grid_sample", grid_sample);
+}
+
+/*!
+ * \brief Infer the output struct info of relax.image.grid_sample.
+ *
+ * The data layout is mapped to NCHW. The output keeps the data's N and C and takes
+ * H_out/W_out from a grid of shape [N, H_out, W_out, 2]; the result is mapped back
+ * to the data layout. When either shape is not statically known, only dtype, rank
+ * and vdevice are propagated.
+ */
+StructInfo InferStructInfoGridSample(const Call& call, const BlockBuilder& ctx) {
+  if (call->args.size() != 2) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "GridSample expects two arguments, while the given number of arguments is "
+                     << call->args.size());
+  }
+
+  const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  const auto* grid_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[1]);
+
+  if (data_sinfo == nullptr) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "GridSample expects the input data to be a Tensor, while the given data is "
+                     << call->args[0]->GetTypeKey());
+  }
+  if (grid_sinfo == nullptr) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "GridSample expects the grid to be a Tensor, while the given grid is "
+                     << call->args[1]->GetTypeKey());
+  }
+  // The grid must be [N, H_out, W_out, 2]. Check the rank before its shape values are
+  // indexed below; a lower-rank grid would otherwise be read out of bounds.
+  if (!grid_sinfo->IsUnknownNdim() && grid_sinfo->ndim != 4) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "GridSample expects the grid to be a 4-D tensor [N, H_out, W_out, 2], "
+                        "while the given grid has ndim "
+                     << grid_sinfo->ndim);
+  }
+
+  const auto* attrs = call->attrs.as<GridSampleAttrs>();
+  auto [data_layout, data2NCHW] = CheckTensorLayout(call, ctx, attrs->layout,
+                                                    /*tgt_layout=*/"NCHW",
+                                                    /*tensor_name=*/"data");
+
+  DataType out_dtype = data_sinfo->dtype;
+
+  // Output shape: [N, C, grid_H, grid_W]
+  // grid shape for NCHW layout input is [N, H_out, W_out, 2]
+  ffi::Optional<ShapeExpr> data_shape = CheckNdimPerLayoutAndGetShape(
+      call, ctx, ffi::GetRef<TensorStructInfo>(data_sinfo), data_layout);
+  const auto* grid_shape = grid_sinfo->shape.as<ShapeExprNode>();
+
+  if (grid_shape != nullptr) {
+    // A statically known trailing dimension must hold exactly two sampling coordinates.
+    const auto* coord_dim = grid_shape->values[3].as<IntImmNode>();
+    if (coord_dim != nullptr && coord_dim->value != 2) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "GridSample expects the last dimension of the grid to be 2, while the "
+                          "given size is "
+                       << coord_dim->value);
+    }
+  }
+
+  if (!data_shape.defined() || grid_shape == nullptr) {
+    return TensorStructInfo(out_dtype, data_layout.ndim(), data_sinfo->vdevice);
+  }
+
+  ffi::Array<PrimExpr> data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values);
+  // grid is [N, H_out, W_out, 2], output is [N, C, H_out, W_out]
+  ffi::Array<PrimExpr> out_NCHW_shape(data_NCHW_shape);
+  out_NCHW_shape.Set(2, grid_shape->values[1]);  // H_out
+  out_NCHW_shape.Set(3, grid_shape->values[2]);  // W_out
+
+  ffi::Array<PrimExpr> out_shape = data2NCHW.BackwardShape(out_NCHW_shape);
+  return TensorStructInfo(ShapeExpr(out_shape), out_dtype, data_sinfo->vdevice);
+}
+
+// Register relax.image.grid_sample: two tensor inputs (data, grid) configured by
+// GridSampleAttrs, output struct info from InferStructInfoGridSample, the kFollow
+// mixed-precision policy, and marked pure (FPurity = true).
+TVM_REGISTER_OP("relax.image.grid_sample")
+    .set_attrs_type<GridSampleAttrs>()
+    .set_num_inputs(2)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .add_argument("grid", "Tensor", "The grid tensor for sampling.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoGridSample)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow)
+    .set_attr<Bool>("FPurity", Bool(true));
+
} // namespace relax
} // namespace tvm
diff --git a/src/relax/op/image/resize.h b/src/relax/op/image/resize.h
index 5125a17804..a208aae092 100644
--- a/src/relax/op/image/resize.h
+++ b/src/relax/op/image/resize.h
@@ -38,6 +38,10 @@ Expr resize2d(Expr data, Expr size, ffi::Array<FloatImm>
roi, ffi::String layout
ffi::String rounding_method, double cubic_alpha, int
cubic_exclude,
double extrapolation_value, ffi::Optional<DataType> out_dtype);
+/*!
+ * \brief Image grid_sample operator.
+ * \param data The input tensor.
+ * \param grid The sampling grid, shaped [N, H_out, W_out, 2] for NCHW data.
+ * \param method Interpolation method: "nearest", "bilinear" or "bicubic".
+ * \param layout Dimension ordering of \p data.
+ * \param padding_mode Handling of out-of-range grid values: "zeros", "border" or "reflection".
+ * \param align_corners If true, the corner pixels of the input and output tensors are aligned.
+ * \return A call to relax.image.grid_sample.
+ */
+Expr grid_sample(Expr data, Expr grid, ffi::String method, ffi::String layout,
+                 ffi::String padding_mode, bool align_corners);
+
} // namespace relax
} // namespace tvm
diff --git a/tests/python/relax/test_frontend_from_exported_program.py
b/tests/python/relax/test_frontend_from_exported_program.py
index ff0f5401ec..a19c36ca22 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -7700,5 +7700,39 @@ def test_scatter_value():
verify_model(ScatterValue(), example_args, {}, Expected)
+def test_grid_sample():
+ class GridSample(Module):
+ def forward(self, input, grid):
+ return torch.nn.functional.grid_sample(
+ input, grid, mode="bilinear", padding_mode="zeros",
align_corners=True
+ )
+
+ @tvm.script.ir_module
+ class expected:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 4, 4), dtype="float32"),
+ grid: R.Tensor((1, 2, 2, 2), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((1, 3, 2, 2), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 2, 2), dtype="float32") =
R.image.grid_sample(
+ input_1,
+ grid,
+ method="bilinear",
+ layout="NCHW",
+ padding_mode="zeros",
+ align_corners=True,
+ )
+ gv: R.Tuple(R.Tensor((1, 3, 2, 2), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (
+ torch.randn(1, 3, 4, 4, dtype=torch.float32),
+ torch.randn(1, 2, 2, 2, dtype=torch.float32),
+ )
+ verify_model(GridSample(), example_args, {}, expected)
+
+
if __name__ == "__main__":
tvm.testing.main()