This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 41a606c726 [Relax][PyTorch] Add support for grid_sample operator (#18483)
41a606c726 is described below

commit 41a606c726dbbd77a3f7c7daaa1d069f705477bd
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Sat Nov 22 01:18:10 2025 +0800

    [Relax][PyTorch] Add support for grid_sample operator (#18483)
    
    ## Related Issue
    
    closes #18475
    
    ## How
    
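    - add the `relax.image.grid_sample` operator (attrs, Python API, struct-info inference, legalization to TOPI)
    - map `aten.grid_sampler_2d` in the PyTorch ExportedProgram frontend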
---
 include/tvm/relax/attrs/image.h                    | 22 ++++++
 .../frontend/torch/exported_program_translator.py  | 27 ++++++++
 python/tvm/relax/op/image/__init__.py              |  2 +-
 python/tvm/relax/op/image/image.py                 | 49 ++++++++++++++
 python/tvm/relax/transform/legalize_ops/image.py   | 13 ++++
 src/relax/op/image/resize.cc                       | 78 ++++++++++++++++++++++
 src/relax/op/image/resize.h                        |  4 ++
 .../relax/test_frontend_from_exported_program.py   | 34 ++++++++++
 8 files changed, 228 insertions(+), 1 deletion(-)

diff --git a/include/tvm/relax/attrs/image.h b/include/tvm/relax/attrs/image.h
index 4d626a022c..b367ce5843 100644
--- a/include/tvm/relax/attrs/image.h
+++ b/include/tvm/relax/attrs/image.h
@@ -78,6 +78,28 @@ struct Resize2DAttrs : public AttrsNodeReflAdapter<Resize2DAttrs> {
   TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Resize2DAttrs", Resize2DAttrs, BaseAttrsNode);
 };  // struct Resize2dAttrs
 
+/*! \brief Attributes for the relax.image.grid_sample operator. */
+struct GridSampleAttrs : public AttrsNodeReflAdapter<GridSampleAttrs> {
+  ffi::String method;
+  ffi::String layout;
+  ffi::String padding_mode;
+  bool align_corners;
+
+  static void RegisterReflection() {
+    namespace refl = tvm::ffi::reflection;
+    refl::ObjectDef<GridSampleAttrs>()
+        .def_ro("method", &GridSampleAttrs::method,
+                "Interpolation method. Can be 'nearest', 'bilinear', or 'bicubic'.")
+        .def_ro("layout", &GridSampleAttrs::layout,
+                "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc.")
+        .def_ro("padding_mode", &GridSampleAttrs::padding_mode,
+                "Padding mode for outside grid values. Can be 'zeros', 'border', or 'reflection'.")
+        .def_ro("align_corners", &GridSampleAttrs::align_corners,
+                "If True, the corner pixels of the input and output tensors are aligned.");
+  }
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.GridSampleAttrs", GridSampleAttrs, BaseAttrsNode);
+};  // struct GridSampleAttrs
+
 }  // namespace relax
 }  // namespace tvm
 
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index e91f006926..64af72c457 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -848,6 +848,43 @@ class ExportedProgramImporter(BaseFXGraphImporter):
         )
         return self.block_builder.emit(relax.op.zeros(size, dtype))
 
+    def _grid_sampler_2d(self, node: fx.Node) -> relax.Var:
+        """Convert torch.nn.functional.grid_sample to relax.op.image.grid_sample.
+
+        The exported ``grid_sampler_2d`` node carries the interpolation and
+        padding modes as integer codes. Unknown codes raise ``ValueError``
+        instead of silently falling back to a default mode.
+        """
+        args = self.retrieve_args(node)
+        data = args[0]
+        grid = args[1]
+        interp_mode = args[2] if len(args) > 2 else 0
+        pad_mode = args[3] if len(args) > 3 else 0
+        align_corners = args[4] if len(args) > 4 else False
+
+        interp_map = {0: "bilinear", 1: "nearest", 2: "bicubic"}
+        pad_map = {0: "zeros", 1: "border", 2: "reflection"}
+
+        if interp_mode not in interp_map:
+            raise ValueError(f"Unsupported grid_sample interpolation mode code: {interp_mode}")
+        if pad_mode not in pad_map:
+            raise ValueError(f"Unsupported grid_sample padding mode code: {pad_mode}")
+        method = interp_map[interp_mode]
+        padding_mode = pad_map[pad_mode]
+
+        # relax.op.image.grid_sample takes the grid in PyTorch layout
+        # ([N, H_out, W_out, 2]), so it is forwarded unchanged.
+        return self.block_builder.emit(
+            relax.op.image.grid_sample(
+                data,
+                grid,
+                method=method,
+                layout="NCHW",
+                padding_mode=padding_mode,
+                align_corners=align_corners,
+            )
+        )
+
     def _scalar_tensor(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
         scalar_value = args[0]
@@ -1222,6 +1259,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "zero_.default": self._zeros_inplace,
             "zeros.default": self._zeros,
             "zeros_like.default": self._zeros_like,
+            "grid_sampler_2d.default": self._grid_sampler_2d,
             # datatype
             "to.dtype": self._to,
             "to.dtype_layout": self._to,
diff --git a/python/tvm/relax/op/image/__init__.py b/python/tvm/relax/op/image/__init__.py
index 10ef635cbf..15c1847b28 100644
--- a/python/tvm/relax/op/image/__init__.py
+++ b/python/tvm/relax/op/image/__init__.py
@@ -15,4 +15,4 @@
 # specific language governing permissions and limitations
 # under the License.
 """Image operators."""
-from .image import resize2d
+from .image import grid_sample, resize2d
diff --git a/python/tvm/relax/op/image/image.py b/python/tvm/relax/op/image/image.py
index afadbf35fb..893f7af90f 100644
--- a/python/tvm/relax/op/image/image.py
+++ b/python/tvm/relax/op/image/image.py
@@ -130,3 +130,57 @@ def resize2d(
         extrapolation_value,
         out_dtype,
     )
+
+
+def grid_sample(
+    data: Expr,
+    grid: Expr,
+    method: str = "bilinear",
+    layout: str = "NCHW",
+    padding_mode: str = "zeros",
+    align_corners: bool = False,
+) -> Expr:
+    """Applies grid sampling to an input feature map.
+
+    Each output location (n, :, h, w) is computed by sampling ``data`` at the
+    normalized coordinate stored in ``grid[n, h, w, :]``, following the
+    convention of ``torch.nn.functional.grid_sample``.
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input data tensor with shape [N, C, H, W] for NCHW layout.
+
+    grid : relax.Expr
+        The sampling grid with shape [N, H_out, W_out, 2]. The values are
+        normalized to [-1, 1], where (-1, -1) is the top-left corner and
+        (1, 1) is the bottom-right corner of ``data``.
+
+    method : str
+        Interpolation method. Can be 'nearest', 'bilinear', or 'bicubic'.
+        Defaults to 'bilinear'.
+
+    layout : str
+        Layout of the input data. Defaults to 'NCHW'.
+
+    padding_mode : str
+        How grid values outside [-1, 1] are handled. Can be 'zeros', 'border',
+        or 'reflection'. Defaults to 'zeros'.
+
+    align_corners : bool
+        If True, the corner pixels of the input and output tensors are aligned.
+        Defaults to False.
+
+    Returns
+    -------
+    result : relax.Expr
+        The sampled output tensor with shape [N, C, H_out, W_out].
+    """
+    return _ffi_api.grid_sample(  # type: ignore
+        data,
+        grid,
+        method,
+        layout,
+        padding_mode,
+        align_corners,
+    )
diff --git a/python/tvm/relax/transform/legalize_ops/image.py b/python/tvm/relax/transform/legalize_ops/image.py
index 1b2a342b0b..7a1c2e92cb 100644
--- a/python/tvm/relax/transform/legalize_ops/image.py
+++ b/python/tvm/relax/transform/legalize_ops/image.py
@@ -37,3 +37,29 @@ def _image_resize2d(bb: BlockBuilder, call: Call) -> Expr:
         bicubic_exclude=call.attrs.cubic_exclude,
         extrapolation_value=call.attrs.extrapolation_value,
     )
+
+
+@register_legalize("relax.image.grid_sample")
+def _image_grid_sample(bb: BlockBuilder, call: Call) -> Expr:
+    # relax.image.grid_sample takes the grid as [N, H_out, W_out, 2] (the
+    # PyTorch convention assumed by its struct-info inference), whereas
+    # topi.image.grid_sample indexes the grid as [N, 2, H_out, W_out]. Permute
+    # the grid before calling topi so the sampling coordinates are read from
+    # the right axis and the output shape matches the inferred struct info.
+    def te_grid_sample(data, grid):
+        grid_c_first = topi.transpose(grid, [0, 3, 1, 2])
+        return topi.image.grid_sample(
+            data,
+            grid_c_first,
+            method=call.attrs.method,
+            layout=call.attrs.layout,
+            padding_mode=call.attrs.padding_mode,
+            align_corners=call.attrs.align_corners,
+        )
+
+    return bb.call_te(
+        te_grid_sample,
+        call.args[0],
+        call.args[1],
+        primfunc_name_hint="grid_sample",
+    )
diff --git a/src/relax/op/image/resize.cc b/src/relax/op/image/resize.cc
index 8b7b8dd2a5..59d845d867 100644
--- a/src/relax/op/image/resize.cc
+++ b/src/relax/op/image/resize.cc
@@ -148,5 +148,98 @@ TVM_REGISTER_OP("relax.image.resize2d")
     .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow)
     .set_attr<Bool>("FPurity", Bool(true));
 
+/* relax.image.grid_sample */
+
+TVM_FFI_STATIC_INIT_BLOCK() { GridSampleAttrs::RegisterReflection(); }
+
+Expr grid_sample(Expr data, Expr grid, ffi::String method, ffi::String layout,
+                 ffi::String padding_mode, bool align_corners) {
+  ObjectPtr<GridSampleAttrs> attrs = ffi::make_object<GridSampleAttrs>();
+  attrs->method = std::move(method);
+  attrs->layout = std::move(layout);
+  attrs->padding_mode = std::move(padding_mode);
+  attrs->align_corners = align_corners;
+
+  static const Op& op = Op::Get("relax.image.grid_sample");
+  return Call(op, {std::move(data), std::move(grid)}, Attrs(attrs), {});
+}
+
+TVM_FFI_STATIC_INIT_BLOCK() {
+  namespace refl = tvm::ffi::reflection;
+  refl::GlobalDef().def("relax.op.image.grid_sample", grid_sample);
+}
+
+StructInfo InferStructInfoGridSample(const Call& call, const BlockBuilder& ctx) {
+  if (call->args.size() != 2) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "GridSample expects two arguments, while the given number of arguments is "
+                     << call->args.size());
+  }
+
+  const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  const auto* grid_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[1]);
+
+  if (data_sinfo == nullptr) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "GridSample expects the input data to be a Tensor, while the given data is "
+                     << call->args[0]->GetTypeKey());
+  }
+  if (grid_sinfo == nullptr) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "GridSample expects the grid to be a Tensor, while the given grid is "
+                     << call->args[1]->GetTypeKey());
+  }
+  // The grid is expected as [N, H_out, W_out, 2]; reject other ranks early so
+  // the grid shape indexing below cannot go out of bounds.
+  if (!grid_sinfo->IsUnknownNdim() && grid_sinfo->ndim != 4) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "GridSample expects the grid to be a 4-D tensor, while the given "
+                     << "grid has ndim " << grid_sinfo->ndim);
+  }
+
+  const auto* attrs = call->attrs.as<GridSampleAttrs>();
+  auto [data_layout, data2NCHW] = CheckTensorLayout(call, ctx, attrs->layout,
+                                                    /*tgt_layout=*/"NCHW",
+                                                    /*tensor_name=*/"data");
+
+  DataType out_dtype = data_sinfo->dtype;
+
+  // Output shape: [N, C, grid_H, grid_W]
+  // grid shape for NCHW layout input is [N, H_out, W_out, 2]
+  ffi::Optional<ShapeExpr> data_shape = CheckNdimPerLayoutAndGetShape(
+      call, ctx, ffi::GetRef<TensorStructInfo>(data_sinfo), data_layout);
+  const auto* grid_shape = grid_sinfo->shape.as<ShapeExprNode>();
+
+  if (!data_shape.defined() || grid_shape == nullptr) {
+    return TensorStructInfo(out_dtype, data_layout.ndim(), data_sinfo->vdevice);
+  }
+
+  // The last grid axis holds the two normalized sampling coordinates.
+  const auto* coord_dim = grid_shape->values[3].as<IntImmNode>();
+  if (coord_dim != nullptr && coord_dim->value != 2) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "GridSample expects the last dimension of the grid to be 2, while the "
+                     << "given grid has last dimension " << coord_dim->value);
+  }
+
+  ffi::Array<PrimExpr> data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values);
+  // grid is [N, H_out, W_out, 2], output is [N, C, H_out, W_out]
+  ffi::Array<PrimExpr> out_NCHW_shape(data_NCHW_shape);
+  out_NCHW_shape.Set(2, grid_shape->values[1]);  // H_out
+  out_NCHW_shape.Set(3, grid_shape->values[2]);  // W_out
+
+  ffi::Array<PrimExpr> out_shape = data2NCHW.BackwardShape(out_NCHW_shape);
+  return TensorStructInfo(ShapeExpr(out_shape), out_dtype, data_sinfo->vdevice);
+}
+
+TVM_REGISTER_OP("relax.image.grid_sample")
+    .set_attrs_type<GridSampleAttrs>()
+    .set_num_inputs(2)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .add_argument("grid", "Tensor", "The grid tensor for sampling.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoGridSample)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow)
+    .set_attr<Bool>("FPurity", Bool(true));
+
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/op/image/resize.h b/src/relax/op/image/resize.h
index 5125a17804..a208aae092 100644
--- a/src/relax/op/image/resize.h
+++ b/src/relax/op/image/resize.h
@@ -38,6 +38,19 @@ Expr resize2d(Expr data, Expr size, ffi::Array<FloatImm> roi, ffi::String layout
               ffi::String rounding_method, double cubic_alpha, int cubic_exclude,
               double extrapolation_value, ffi::Optional<DataType> out_dtype);
 
+/*!
+ * \brief Image grid_sample operator.
+ * \param data The input tensor, laid out according to \p layout.
+ * \param grid The sampling grid of shape [N, H_out, W_out, 2], normalized to [-1, 1].
+ * \param method The interpolation method: "nearest", "bilinear", or "bicubic".
+ * \param layout The layout of \p data, e.g. "NCHW".
+ * \param padding_mode Handling of out-of-range grid values: "zeros", "border", or "reflection".
+ * \param align_corners Whether the corner pixels of the input and output tensors are aligned.
+ * \return The call expression to relax.image.grid_sample.
+ */
+Expr grid_sample(Expr data, Expr grid, ffi::String method, ffi::String layout,
+                 ffi::String padding_mode, bool align_corners);
+
 }  // namespace relax
 }  // namespace tvm
 
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index ff0f5401ec..a19c36ca22 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -7700,5 +7700,41 @@ def test_scatter_value():
     verify_model(ScatterValue(), example_args, {}, Expected)
 
 
+def test_grid_sample():
+    class GridSample(Module):
+        def forward(self, input, grid):
+            return torch.nn.functional.grid_sample(
+                input, grid, mode="bilinear", padding_mode="zeros", align_corners=True
+            )
+
+    # Use a non-square grid whose spatial sizes differ from the input's so that
+    # swapping H_out/W_out (or reading them from the wrong grid axis) is caught.
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 4, 4), dtype="float32"),
+            grid: R.Tensor((1, 5, 6, 2), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((1, 3, 5, 6), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 5, 6), dtype="float32") = R.image.grid_sample(
+                    input_1,
+                    grid,
+                    method="bilinear",
+                    layout="NCHW",
+                    padding_mode="zeros",
+                    align_corners=True,
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 5, 6), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (
+        torch.randn(1, 3, 4, 4, dtype=torch.float32),
+        torch.randn(1, 5, 6, 2, dtype=torch.float32),
+    )
+    verify_model(GridSample(), example_args, {}, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to