This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new fa26a05162 [Relax][PyTorch] Add Meshgrid Op Support for Exported Program and FX graph (#17904)
fa26a05162 is described below

commit fa26a05162c933729cb99eaaaf0e716fee553502
Author: Deivanayaki S <[email protected]>
AuthorDate: Tue May 6 21:24:06 2025 +0530

    [Relax][PyTorch] Add Meshgrid Op Support for Exported Program and FX graph (#17904)
    
    * add torch.meshgrid op support into torch frontends
    
    * remove trailing whitespaces
    
    * fix lint issues
    
    * fix space issue in test script
    
    * fix func definition issue
    
    * set relax var shape to fix the unity issue
    
    * fix format issue in input declaration
    
    * fix lint issue
    
    * fix cpp lints
    
    * fix cpp lint issue in manipulate file
    
    * fix wrong input in struct info test script
    
    * add one more mapping for meshgrid in exported program
    
    ---------
    
    Co-authored-by: deivanayakisankaralingam <deiva@Deivanayaki.>
---
 include/tvm/relax/attrs/manipulate.h               |   9 ++
 .../frontend/torch/base_fx_graph_translator.py     |  20 ++++
 .../frontend/torch/exported_program_translator.py  |   2 +
 python/tvm/relax/frontend/torch/fx_translator.py   |   1 +
 python/tvm/relax/op/__init__.py                    |   1 +
 python/tvm/relax/op/manipulate.py                  |  23 +++++
 .../tvm/relax/transform/legalize_ops/manipulate.py |  19 ++++
 python/tvm/script/ir_builder/relax/ir.py           |   2 +
 src/relax/op/tensor/manipulate.cc                  | 103 +++++++++++++++++++++
 src/relax/op/tensor/manipulate.h                   |   8 ++
 .../relax/test_frontend_from_exported_program.py   |  53 +++++++++++
 tests/python/relax/test_frontend_from_fx.py        |  60 ++++++++++++
 tests/python/relax/test_op_manipulate.py           |  45 +++++++++
 13 files changed, 346 insertions(+)

diff --git a/include/tvm/relax/attrs/manipulate.h b/include/tvm/relax/attrs/manipulate.h
index 943d2f4d0d..2993223079 100644
--- a/include/tvm/relax/attrs/manipulate.h
+++ b/include/tvm/relax/attrs/manipulate.h
@@ -196,6 +196,15 @@ struct IndexPutAttrs : public tvm::AttrsNode<IndexPutAttrs> {
   }
 };  // struct IndexPutAttrs
 
+/*! \brief Attribute used in meshgrid operator */
+struct MeshgridAttrs : public tvm::AttrsNode<MeshgridAttrs> {
+  Optional<String> indexing;
+
+  TVM_DECLARE_ATTRS(MeshgridAttrs, "relax.attrs.MeshgridAttrs") {
+    TVM_ATTR_FIELD(indexing).describe("Specifies how the grid dimensions are ordered.");
+  }
+};
+
 /*! \brief Attributes used in scatter_elements operators */
 struct ScatterElementsAttrs : public tvm::AttrsNode<ScatterElementsAttrs> {
   Integer axis;
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 48869767ad..0b48a015e1 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1179,6 +1179,26 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         indices = args[1]
         return self.block_builder.emit(relax.op.index_tensor(args[0], indices))
 
+    def _meshgrid(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        indexing = args[1] if len(node.args) > 1 else node.kwargs.get("indexing", "ij")
+        input_list = args[0]
+
+        # Single input: return as-is, meshgrid not applicable.
+        if len(input_list) == 1:
+            return input_list
+        new_inputs = []
+        for i, item in enumerate(input_list):
+            if item.struct_info.ndim == 1:
+                new_inputs.append(item)
+            elif item.struct_info.ndim == 0:  # Change scalar value into 1D
+                const_tensor = relax.op.reshape(item, (1,))
+                new_inputs.append(const_tensor)
+            else:
+                raise TypeError(f"Unsupported meshgrid input type at index {i}: {type(item)}")
+
+        return self.block_builder.emit(relax.op.meshgrid(new_inputs, indexing=indexing))
+
     def _permute(self, node: fx.Node) -> relax.Var:
         import torch  # type: ignore
 
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index df532fd1ea..8b584906c8 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -439,6 +439,8 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "gather.default": self._gather,
             "index.Tensor": self._index_tensor,
             "index_put_.default": self._index_put,
+            "meshgrid.indexing": self._meshgrid,
+            "meshgrid.default": self._meshgrid,
             "narrow.default": self._narrow,
             "permute.default": self._permute,
             "repeat.default": self._repeat,
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 5f65f86a43..f1223b6243 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -804,6 +804,7 @@ class TorchFXImporter(BaseFXGraphImporter):
             "flip": self._flip,
             "gather": self._gather,
             "index_put_": self._index_put,
+            "meshgrid": self._meshgrid,
             "narrow": self._narrow,
             "numel": self._numel,
             "permute": self._permute,
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 7b8c34b641..be5306c9f4 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -97,6 +97,7 @@ from .manipulate import (
     gather_nd,
     index_put,
     index_tensor,
+    meshgrid,
     layout_transform,
     one_hot,
     permute_dims,
diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py
index 13334d1479..b52aced59a 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -646,6 +646,29 @@ def index_put(
     return _ffi_api.index_put(data, indices, values, accumulate)  # type: ignore
 
 
+def meshgrid(tensors: Union[Expr, List[Expr]], indexing: Optional[str] = "ij") -> Expr:
+    """Generate coordinate grids from input tensors.
+
+    Parameters
+    ----------
+    tensors : Union[relax.Expr, List[relax.Expr]]
+        An Expr in Tuple type, containing 1D tensors (or scalars promoted to 1D)
+        to generate coordinate grids from, or a list of such tensors.
+
+    indexing : Optional[str]
+        The indexing mode, either "ij" (matrix indexing) or "xy" (Cartesian indexing).
+        Defaults to "ij".
+
+    Returns
+    -------
+    result : relax.Expr
+        A Tuple of tensors representing the coordinate grids.
+    """
+    if isinstance(tensors, (list, tuple)):
+        tensors = RxTuple(tensors)
+    return _ffi_api.meshgrid(tensors, indexing)
+
+
 def scatter_elements(
     data: Expr, indices: Expr, updates: Expr, axis: int = 0, reduction: str = "update"
 ):
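
A minimal usage sketch of the new relax.op.meshgrid wrapper; the "main" function and the variable names are illustrative only, and the annotated shapes rely on the struct-info inference added in manipulate.cc below:

    from tvm import relax
    from tvm.script import relax as R

    x = relax.Var("x", R.Tensor((3,), "float32"))
    y = relax.Var("y", R.Tensor((4,), "float32"))

    bb = relax.BlockBuilder()
    with bb.function("main", [x, y]):
        # A Python list/tuple of tensors is wrapped into a relax Tuple by the helper.
        grids = bb.emit(relax.op.meshgrid([x, y], indexing="ij"))
        bb.emit_func_output(grids)

    # grids is bound to a tuple of two R.Tensor((3, 4), "float32") grids.
    bb.get().show()
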
diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py
index a66b60c013..04b16d4db3 100644
--- a/python/tvm/relax/transform/legalize_ops/manipulate.py
+++ b/python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -215,6 +215,25 @@ def _index_put(bb: BlockBuilder, call: Call) -> Expr:
     )
 
 
+@register_legalize("relax.meshgrid")
+def _meshgrid(bb: BlockBuilder, call: Call) -> Expr:
+    t = call.args[0]
+    n_field = len(t.struct_info.fields)
+    while isinstance(t, Var):
+        binding = bb.lookup_binding(t)
+        if not isinstance(binding, (Tuple, Var)):
+            break
+        t = binding
+
+    assert isinstance(t, (Tuple, Var))
+    fields = (
+        t.fields if isinstance(t, Tuple) else [bb.emit(TupleGetItem(t, i)) for i in range(n_field)]
+    )
+    return bb.call_te(
+        topi.meshgrid, fields, "ij" if call.attrs.indexing is None else call.attrs.indexing
+    )
+
+
 @register_legalize("relax.scatter_elements")
 def _scatter_elements(bb: BlockBuilder, call: Call) -> Expr:
     return bb.call_te(
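
A sketch of exercising this legalization, assuming the same toy module as in the previous sketch; after LegalizeOps the relax.meshgrid call is expected to be lowered into a call_tir of a kernel generated from topi.meshgrid:

    from tvm import relax
    from tvm.script import relax as R

    x = relax.Var("x", R.Tensor((3,), "float32"))
    y = relax.Var("y", R.Tensor((4,), "float32"))
    bb = relax.BlockBuilder()
    with bb.function("main", [x, y]):
        out = bb.emit(relax.op.meshgrid([x, y], indexing="ij"))
        bb.emit_func_output(out)

    # The relax.meshgrid call should now be replaced by a call_tir binding.
    legalized = relax.transform.LegalizeOps()(bb.get())
    legalized.show()
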
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index d2952ed8e0..6d5dbc20cb 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -125,6 +125,7 @@ from tvm.relax.op import (
     maximum,
     mean,
     memory,
+    meshgrid,
     min,
     minimum,
     mod,
@@ -811,6 +812,7 @@ __all__ = [
     "maximum",
     "mean",
     "memory",
+    "meshgrid",
     "metal",
     "min",
     "minimum",
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index 482ebe5cac..dd46f23974 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -2095,6 +2095,109 @@ TVM_REGISTER_OP("relax.index_put")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoIndexPut)
     .set_attr<Bool>("FPurity", Bool(true));
 
+/* relax.meshgrid */
+TVM_REGISTER_NODE_TYPE(MeshgridAttrs);
+
+Expr meshgrid(Expr tensors, Optional<String> indexing) {
+  ObjectPtr<MeshgridAttrs> attrs = make_object<MeshgridAttrs>();
+  attrs->indexing = indexing;
+  static const Op& op = Op::Get("relax.meshgrid");
+  return Call(op, {std::move(tensors)}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.meshgrid").set_body_typed(meshgrid);
+
+StructInfo InferStructInfoMeshgrid(const Call& call, const BlockBuilder& ctx) {
+  if (call->args.size() != 1) {
+    ctx->ReportFatal(Diagnostic::Error(call) << "meshgrid op expects 1 Tuple input argument.");
+  }
+  Array<TensorStructInfo> input_sinfo = GetTensorStructInfoFromTuple(call, ctx, call->args[0]);
+
+  int n_inputs = input_sinfo.size();
+
+  if (n_inputs == 0) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "meshgrid expects at least one 1D tensor in the input 
Tuple.");
+  }
+
+  std::vector<PrimExpr> lengths;
+  DataType common_dtype = DataType::Void();
+  bool shape_unknown = false;
+  Optional<VDevice> vdev = NullOpt;
+  bool vdevice_unknown = false;
+
+  for (int i = 0; i < n_inputs; ++i) {
+    const TensorStructInfo& sinfo = input_sinfo[i];
+
+    if (sinfo->ndim != 1) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "meshgrid expects each input tensor to be 1D. Got 
ndim = " << sinfo->ndim
+                       << " at index " << i);
+    }
+
+    if (sinfo->dtype.is_void()) {
+      continue;
+    } else if (common_dtype.is_void()) {
+      common_dtype = sinfo->dtype;
+    } else if (sinfo->dtype != common_dtype) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "meshgrid expects all input tensors to have the same 
dtype. Found "
+                       << sinfo->dtype << " and " << common_dtype);
+    }
+
+    const auto* shape_expr = sinfo->shape.as<ShapeExprNode>();
+    if (shape_expr && shape_expr->values.size() == 1) {
+      lengths.push_back(shape_expr->values[0]);
+    } else {
+      shape_unknown = true;
+    }
+
+    if (!vdevice_unknown) {
+      if (sinfo->vdevice.defined()) {
+        if (!vdev.defined()) {
+          vdev = sinfo->vdevice.value();
+        } else if (sinfo->vdevice.value() != vdev) {
+          vdevice_unknown = true;
+        }
+      }
+    }
+  }
+
+  Array<PrimExpr> out_shape;
+  if (!shape_unknown && lengths.size() == static_cast<size_t>(n_inputs)) {
+    for (const PrimExpr& dim : lengths) {
+      out_shape.push_back(dim);
+    }
+  }
+
+  Array<StructInfo> out_fields;
+  for (int i = 0; i < n_inputs; ++i) {
+    if (!out_shape.empty()) {
+      if (!vdevice_unknown) {
+        out_fields.push_back(TensorStructInfo(ShapeExpr(out_shape), common_dtype, vdev));
+      } else {
+        out_fields.push_back(TensorStructInfo(ShapeExpr(out_shape), common_dtype));
+      }
+    } else {
+      if (!vdevice_unknown) {
+        out_fields.push_back(TensorStructInfo(common_dtype, n_inputs, vdev));
+      } else {
+        out_fields.push_back(TensorStructInfo(common_dtype, n_inputs));
+      }
+    }
+  }
+
+  return TupleStructInfo(out_fields);
+}
+
+TVM_REGISTER_OP("relax.meshgrid")
+    .set_attrs_type<MeshgridAttrs>()
+    .set_num_inputs(1)
+    .add_argument("tensors", "Tuple of Tensors", "The input list of tensors.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoMeshgrid)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow)
+    .set_attr<Bool>("FPurity", Bool(true));
+
 /* relax.scatter_elements */
 TVM_REGISTER_NODE_TYPE(ScatterElementsAttrs);
 
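
A small sketch of the inference rule implemented by InferStructInfoMeshgrid; the variable names are illustrative, and the pattern mirrors the test_op_manipulate.py cases added below. Inputs with fully known 1D lengths yield static grids, while an input of unknown shape falls back to ndim-only tensors:

    from tvm import relax
    from tvm.script import relax as R

    bb = relax.BlockBuilder()
    known = relax.Var("known", R.Tensor((3,), "float32"))
    unknown = relax.Var("unknown", R.Tensor("float32", ndim=1))

    static_call = bb.normalize(relax.op.meshgrid((known, known), indexing="ij"))
    dynamic_call = bb.normalize(relax.op.meshgrid((known, unknown), indexing="ij"))

    print(static_call.struct_info)   # two R.Tensor((3, 3), "float32") fields
    print(dynamic_call.struct_info)  # two R.Tensor(dtype="float32", ndim=2) fields
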
diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h
index 2e4c92c150..12d70da72a 100644
--- a/src/relax/op/tensor/manipulate.h
+++ b/src/relax/op/tensor/manipulate.h
@@ -231,6 +231,14 @@ Expr index_tensor(Expr data, Expr indices);
  */
 Expr index_put(Expr data, Expr indices, Expr values, bool accumulate = false);
 
+/*!
+ * \brief Generate coordinate grids from input 1D tensors.
+ * \param tensors A tuple of 1D tensors representing coordinate vectors.
+ * \param indexing Indexing mode, either "ij" (matrix indexing) or "xy" (Cartesian indexing).
+ * \return A tuple of tensors representing the coordinate grids.
+ */
+Expr meshgrid(Expr tensors, Optional<String> indexing = String("ij"));
+
 /*!
  * \brief Scatter updates into an array according to indices.
  * \param data The input tensor.
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index f0bb33964e..b07070ddc9 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -2920,6 +2920,59 @@ def test_flatten():
     verify_model(Flatten(), example_args, {}, expected1)
 
 
+def test_meshgrid():
+    class Meshgrid1(Module):
+        def forward(self, input1, input2):
+            return torch.meshgrid((input1, input2), indexing="ij")
+
+    class Meshgrid2(Module):
+        def forward(self, input1, input2):
+            return torch.meshgrid((input1, input2), indexing="xy")
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input1: R.Tensor((3,), dtype="float32"), input2: R.Tensor((3,), dtype="float32")
+        ) -> R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tuple(
+                    R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")
+                ) = R.meshgrid((input1, input2), indexing="ij")
+                lv1: R.Tensor((3, 3), dtype="float32") = lv[0]
+                lv2: R.Tensor((3, 3), dtype="float32") = lv[1]
+                gv: R.Tuple(
+                    R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")
+                ) = (lv1, lv2)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(
+            input1: R.Tensor((3,), dtype="float32"), input2: R.Tensor((3,), dtype="float32")
+        ) -> R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tuple(
+                    R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")
+                ) = R.meshgrid((input1, input2), indexing="xy")
+                lv1: R.Tensor((3, 3), dtype="float32") = lv[0]
+                lv2: R.Tensor((3, 3), dtype="float32") = lv[1]
+                gv: R.Tuple(
+                    R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")
+                ) = (lv1, lv2)
+                R.output(gv)
+            return gv
+
+    example_args = (
+        torch.randn(3, dtype=torch.float32),
+        torch.randn(3, dtype=torch.float32),
+    )
+    verify_model(Meshgrid1(), example_args, {}, expected1)
+    verify_model(Meshgrid2(), example_args, {}, expected2)
+
+
 def test_permute():
     class Permute1(Module):
         def forward(self, x):
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 490a2309aa..2bb2a84441 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -3528,6 +3528,66 @@ def test_datatype():
     verify_model(AsType(), input_info, {}, expected1)
 
 
+def test_meshgrid():
+    input_infos = [
+        (
+            [
+                3,
+            ],
+            "float32",
+        ),
+        (
+            [
+                3,
+            ],
+            "float32",
+        ),
+    ]
+
+    class Meshgrid1(Module):
+        def forward(self, input1, input2):
+            return torch.meshgrid((input1, input2), indexing="ij")
+
+    class Meshgrid2(Module):
+        def forward(self, input1, input2):
+            return torch.meshgrid((input1, input2), indexing="xy")
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((3,), dtype="float32"), inp_1: R.Tensor((3,), dtype="float32")
+        ) -> R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tuple(
+                    R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")
+                ) = R.meshgrid((inp_0, inp_1), indexing="ij")
+                gv: R.Tuple(
+                    R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")
+                ) = lv
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(
+            inp_0: R.Tensor((3,), dtype="float32"), inp_1: R.Tensor((3,), dtype="float32")
+        ) -> R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tuple(
+                    R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")
+                ) = R.meshgrid((inp_0, inp_1), indexing="xy")
+                gv: R.Tuple(
+                    R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")
+                ) = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Meshgrid1(), input_infos, {}, expected1)
+    verify_model(Meshgrid2(), input_infos, {}, expected2)
+
+
 def test_permute():
     input_info = [([1, 2, 3, 4], "float32")]
 
diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py
index 28e762d9a4..2e171a0a5b 100644
--- a/tests/python/relax/test_op_manipulate.py
+++ b/tests/python/relax/test_op_manipulate.py
@@ -3502,6 +3502,51 @@ def test_scatter_nd_infer_struct_info():
     )
 
 
+def test_meshgrid_infer_struct_info():
+    bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
+    t0 = relax.Var("t0", R.Tensor((3,), "float32"))
+    t1 = relax.Var("t1", R.Tensor((4,), "float32"))
+    t2 = relax.Var("t2", R.Tensor("float32", ndim=1))
+    t3 = relax.Var("t3", R.Tensor((5,), "float32", vdev0))
+
+    _check_inference(
+        bb,
+        relax.op.meshgrid((t0, t1), indexing="ij"),
+        relax.TupleStructInfo(
+            [relax.TensorStructInfo((3, 4), "float32"), relax.TensorStructInfo((3, 4), "float32")]
+        ),
+    )
+
+    _check_inference(
+        bb,
+        relax.op.meshgrid((t3, t1), indexing="ij"),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo((5, 4), "float32", vdev0),
+                relax.TensorStructInfo((5, 4), "float32", vdev0),
+            ]
+        ),
+    )
+
+    _check_inference(
+        bb,
+        relax.op.meshgrid((t2, t1), indexing="xy"),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=2),
+                relax.TensorStructInfo(dtype="float32", ndim=2),
+            ]
+        ),
+    )
+
+    _check_inference(
+        bb,
+        relax.op.meshgrid((t0,), indexing="ij"),
+        relax.TupleStructInfo([relax.TensorStructInfo((3,), "float32")]),
+    )
+
+
 def test_one_hot_infer_struct_info():
     bb = relax.BlockBuilder()
 
