This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new fa26a05162 [Relax][PyTorch] Add Meshgrid Op Support for Exported
Program and FX graph (#17904)
fa26a05162 is described below
commit fa26a05162c933729cb99eaaaf0e716fee553502
Author: Deivanayaki S <[email protected]>
AuthorDate: Tue May 6 21:24:06 2025 +0530
[Relax][PyTorch] Add Meshgrid Op Support for Exported Program and FX graph
(#17904)
* add torch.meshgrid op support into torch frontends
* remove trailing whitespaces
* fix lint issues
* fix space issue in test script
* fix func definition issue
* set relax var shape to fix the unity issue
* fix format issue in input declaration
* fix lint issue
* fix cpp lints
* fix cpp lint issue in manipulate file
* fix wrong input in struct info test script
* add one more mapping for meshgrid in exported program
---------
Co-authored-by: deivanayakisankaralingam <deiva@Deivanayaki.>
---
include/tvm/relax/attrs/manipulate.h | 9 ++
.../frontend/torch/base_fx_graph_translator.py | 20 ++++
.../frontend/torch/exported_program_translator.py | 2 +
python/tvm/relax/frontend/torch/fx_translator.py | 1 +
python/tvm/relax/op/__init__.py | 1 +
python/tvm/relax/op/manipulate.py | 23 +++++
.../tvm/relax/transform/legalize_ops/manipulate.py | 19 ++++
python/tvm/script/ir_builder/relax/ir.py | 2 +
src/relax/op/tensor/manipulate.cc | 103 +++++++++++++++++++++
src/relax/op/tensor/manipulate.h | 8 ++
.../relax/test_frontend_from_exported_program.py | 53 +++++++++++
tests/python/relax/test_frontend_from_fx.py | 60 ++++++++++++
tests/python/relax/test_op_manipulate.py | 45 +++++++++
13 files changed, 346 insertions(+)
diff --git a/include/tvm/relax/attrs/manipulate.h
b/include/tvm/relax/attrs/manipulate.h
index 943d2f4d0d..2993223079 100644
--- a/include/tvm/relax/attrs/manipulate.h
+++ b/include/tvm/relax/attrs/manipulate.h
@@ -196,6 +196,15 @@ struct IndexPutAttrs : public
tvm::AttrsNode<IndexPutAttrs> {
}
}; // struct IndexPutAttrs
+/*! \brief Attribute used in meshgrid operator */
+struct MeshgridAttrs : public tvm::AttrsNode<MeshgridAttrs> {
+ Optional<String> indexing;
+
+ TVM_DECLARE_ATTRS(MeshgridAttrs, "relax.attrs.MeshgridAttrs") {
+ TVM_ATTR_FIELD(indexing).describe("Specifies how the grid dimensions are
ordered.");
+ }
+};
+
/*! \brief Attributes used in scatter_elements operators */
struct ScatterElementsAttrs : public tvm::AttrsNode<ScatterElementsAttrs> {
Integer axis;
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 48869767ad..0b48a015e1 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1179,6 +1179,26 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
indices = args[1]
return self.block_builder.emit(relax.op.index_tensor(args[0], indices))
+ def _meshgrid(self, node: fx.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ indexing = args[1] if len(node.args) > 1 else
node.kwargs.get("indexing", "ij")
+ input_list = args[0]
+
+ # Single input: return as-is, meshgrid not applicable.
+ if len(input_list) == 1:
+ return input_list
+ new_inputs = []
+ for i, item in enumerate(input_list):
+ if item.struct_info.ndim == 1:
+ new_inputs.append(item)
+ elif item.struct_info.ndim == 0: # Change scalar value into 1D
+ const_tensor = relax.op.reshape(item, (1,))
+ new_inputs.append(const_tensor)
+ else:
+ raise TypeError(f"Unsupported meshgrid input type at index
{i}: {type(item)}")
+
+ return self.block_builder.emit(relax.op.meshgrid(new_inputs,
indexing=indexing))
+
def _permute(self, node: fx.Node) -> relax.Var:
import torch # type: ignore
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index df532fd1ea..8b584906c8 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -439,6 +439,8 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"gather.default": self._gather,
"index.Tensor": self._index_tensor,
"index_put_.default": self._index_put,
+ "meshgrid.indexing": self._meshgrid,
+ "meshgrid.default": self._meshgrid,
"narrow.default": self._narrow,
"permute.default": self._permute,
"repeat.default": self._repeat,
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py
b/python/tvm/relax/frontend/torch/fx_translator.py
index 5f65f86a43..f1223b6243 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -804,6 +804,7 @@ class TorchFXImporter(BaseFXGraphImporter):
"flip": self._flip,
"gather": self._gather,
"index_put_": self._index_put,
+ "meshgrid": self._meshgrid,
"narrow": self._narrow,
"numel": self._numel,
"permute": self._permute,
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 7b8c34b641..be5306c9f4 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -97,6 +97,7 @@ from .manipulate import (
gather_nd,
index_put,
index_tensor,
+ meshgrid,
layout_transform,
one_hot,
permute_dims,
diff --git a/python/tvm/relax/op/manipulate.py
b/python/tvm/relax/op/manipulate.py
index 13334d1479..b52aced59a 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -646,6 +646,29 @@ def index_put(
return _ffi_api.index_put(data, indices, values, accumulate) # type:
ignore
+def meshgrid(tensors: Union[Expr, List[Expr]], indexing: Optional[str] = "ij")
-> Expr:
+ """Generate coordinate grids from input tensors.
+
+ Parameters
+ ----------
+ tensors : Union[relax.Expr, List[relax.Expr]]
+ An Expr in Tuple type, containing 1D tensors (or scalars promoted to
1D)
+ to generate coordinate grids from, or a list of such tensors.
+
+ indexing : Optional[str]
+ The indexing mode, either "ij" (matrix indexing) or "xy" (Cartesian
indexing).
+ Defaults to "ij".
+
+ Returns
+ -------
+ result : relax.Expr
+ A Tuple of tensors representing the coordinate grids.
+ """
+ if isinstance(tensors, (list, tuple)):
+ tensors = RxTuple(tensors)
+ return _ffi_api.meshgrid(tensors, indexing)
+
+
def scatter_elements(
data: Expr, indices: Expr, updates: Expr, axis: int = 0, reduction: str =
"update"
):
diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py
b/python/tvm/relax/transform/legalize_ops/manipulate.py
index a66b60c013..04b16d4db3 100644
--- a/python/tvm/relax/transform/legalize_ops/manipulate.py
+++ b/python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -215,6 +215,25 @@ def _index_put(bb: BlockBuilder, call: Call) -> Expr:
)
+@register_legalize("relax.meshgrid")
+def _meshgrid(bb: BlockBuilder, call: Call) -> Expr:
+ t = call.args[0]
+ n_field = len(t.struct_info.fields)
+ while isinstance(t, Var):
+ binding = bb.lookup_binding(t)
+ if not isinstance(binding, (Tuple, Var)):
+ break
+ t = binding
+
+ assert isinstance(t, (Tuple, Var))
+ fields = (
+ t.fields if isinstance(t, Tuple) else [bb.emit(TupleGetItem(t, i)) for
i in range(n_field)]
+ )
+ return bb.call_te(
+ topi.meshgrid, fields, "ij" if call.attrs.indexing is None else
call.attrs.indexing
+ )
+
+
@register_legalize("relax.scatter_elements")
def _scatter_elements(bb: BlockBuilder, call: Call) -> Expr:
return bb.call_te(
diff --git a/python/tvm/script/ir_builder/relax/ir.py
b/python/tvm/script/ir_builder/relax/ir.py
index d2952ed8e0..6d5dbc20cb 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -125,6 +125,7 @@ from tvm.relax.op import (
maximum,
mean,
memory,
+ meshgrid,
min,
minimum,
mod,
@@ -811,6 +812,7 @@ __all__ = [
"maximum",
"mean",
"memory",
+ "meshgrid",
"metal",
"min",
"minimum",
diff --git a/src/relax/op/tensor/manipulate.cc
b/src/relax/op/tensor/manipulate.cc
index 482ebe5cac..dd46f23974 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -2095,6 +2095,109 @@ TVM_REGISTER_OP("relax.index_put")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoIndexPut)
.set_attr<Bool>("FPurity", Bool(true));
+/* relax.meshgrid */
+TVM_REGISTER_NODE_TYPE(MeshgridAttrs);
+
+Expr meshgrid(Expr tensors, Optional<String> indexing) {
+ ObjectPtr<MeshgridAttrs> attrs = make_object<MeshgridAttrs>();
+ attrs->indexing = indexing;
+ static const Op& op = Op::Get("relax.meshgrid");
+ return Call(op, {std::move(tensors)}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.meshgrid").set_body_typed(meshgrid);
+
+StructInfo InferStructInfoMeshgrid(const Call& call, const BlockBuilder& ctx) {
+ if (call->args.size() != 1) {
+ ctx->ReportFatal(Diagnostic::Error(call) << "meshgrid op expects 1 Tuple
input argument.");
+ }
+ Array<TensorStructInfo> input_sinfo = GetTensorStructInfoFromTuple(call,
ctx, call->args[0]);
+
+ int n_inputs = input_sinfo.size();
+
+ if (n_inputs == 0) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "meshgrid expects at least one 1D tensor in the input
Tuple.");
+ }
+
+ std::vector<PrimExpr> lengths;
+ DataType common_dtype = DataType::Void();
+ bool shape_unknown = false;
+ Optional<VDevice> vdev = NullOpt;
+ bool vdevice_unknown = false;
+
+ for (int i = 0; i < n_inputs; ++i) {
+ const TensorStructInfo& sinfo = input_sinfo[i];
+
+ if (sinfo->ndim != 1) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "meshgrid expects each input tensor to be 1D. Got
ndim = " << sinfo->ndim
+ << " at index " << i);
+ }
+
+ if (sinfo->dtype.is_void()) {
+ continue;
+ } else if (common_dtype.is_void()) {
+ common_dtype = sinfo->dtype;
+ } else if (sinfo->dtype != common_dtype) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "meshgrid expects all input tensors to have the same
dtype. Found "
+ << sinfo->dtype << " and " << common_dtype);
+ }
+
+ const auto* shape_expr = sinfo->shape.as<ShapeExprNode>();
+ if (shape_expr && shape_expr->values.size() == 1) {
+ lengths.push_back(shape_expr->values[0]);
+ } else {
+ shape_unknown = true;
+ }
+
+ if (!vdevice_unknown) {
+ if (sinfo->vdevice.defined()) {
+ if (!vdev.defined()) {
+ vdev = sinfo->vdevice.value();
+ } else if (sinfo->vdevice.value() != vdev) {
+ vdevice_unknown = true;
+ }
+ }
+ }
+ }
+
+ Array<PrimExpr> out_shape;
+ if (!shape_unknown && lengths.size() == static_cast<size_t>(n_inputs)) {
+ for (const PrimExpr& dim : lengths) {
+ out_shape.push_back(dim);
+ }
+ }
+
+ Array<StructInfo> out_fields;
+ for (int i = 0; i < n_inputs; ++i) {
+ if (!out_shape.empty()) {
+ if (!vdevice_unknown) {
+ out_fields.push_back(TensorStructInfo(ShapeExpr(out_shape),
common_dtype, vdev));
+ } else {
+ out_fields.push_back(TensorStructInfo(ShapeExpr(out_shape),
common_dtype));
+ }
+ } else {
+ if (!vdevice_unknown) {
+ out_fields.push_back(TensorStructInfo(common_dtype, n_inputs, vdev));
+ } else {
+ out_fields.push_back(TensorStructInfo(common_dtype, n_inputs));
+ }
+ }
+ }
+
+ return TupleStructInfo(out_fields);
+}
+
+TVM_REGISTER_OP("relax.meshgrid")
+ .set_attrs_type<MeshgridAttrs>()
+ .set_num_inputs(1)
+ .add_argument("tensors", "Tuple of Tensors", "The input list of tensors.")
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoMeshgrid)
+ .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kFollow)
+ .set_attr<Bool>("FPurity", Bool(true));
+
/* relax.scatter_elements */
TVM_REGISTER_NODE_TYPE(ScatterElementsAttrs);
diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h
index 2e4c92c150..12d70da72a 100644
--- a/src/relax/op/tensor/manipulate.h
+++ b/src/relax/op/tensor/manipulate.h
@@ -231,6 +231,14 @@ Expr index_tensor(Expr data, Expr indices);
*/
Expr index_put(Expr data, Expr indices, Expr values, bool accumulate = false);
+/*!
+ * \brief Generate coordinate grids from input 1D tensors.
+ * \param tensors A tuple of 1D tensors representing coordinate vectors.
+ * \param indexing Indexing mode, either "ij" (matrix indexing) or "xy"
(Cartesian indexing).
+ * \return A tuple of tensors representing the coordinate grids.
+ */
+Expr meshgrid(Expr tensors, Optional<String> indexing = String("ij"));
+
/*!
* \brief Scatter updates into an array according to indices.
* \param data The input tensor.
diff --git a/tests/python/relax/test_frontend_from_exported_program.py
b/tests/python/relax/test_frontend_from_exported_program.py
index f0bb33964e..b07070ddc9 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -2920,6 +2920,59 @@ def test_flatten():
verify_model(Flatten(), example_args, {}, expected1)
+def test_meshgrid():
+ class Meshgrid1(Module):
+ def forward(self, input1, input2):
+ return torch.meshgrid((input1, input2), indexing="ij")
+
+ class Meshgrid2(Module):
+ def forward(self, input1, input2):
+ return torch.meshgrid((input1, input2), indexing="xy")
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ input1: R.Tensor((3,), dtype="float32"), input2: R.Tensor((3,),
dtype="float32")
+ ) -> R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3),
dtype="float32")):
+ with R.dataflow():
+ lv: R.Tuple(
+ R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3),
dtype="float32")
+ ) = R.meshgrid((input1, input2), indexing="ij")
+ lv1: R.Tensor((3, 3), dtype="float32") = lv[0]
+ lv2: R.Tensor((3, 3), dtype="float32") = lv[1]
+ gv: R.Tuple(
+ R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3),
dtype="float32")
+ ) = (lv1, lv2)
+ R.output(gv)
+ return gv
+
+ @tvm.script.ir_module
+ class expected2:
+ @R.function
+ def main(
+ input1: R.Tensor((3,), dtype="float32"), input2: R.Tensor((3,),
dtype="float32")
+ ) -> R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3),
dtype="float32")):
+ with R.dataflow():
+ lv: R.Tuple(
+ R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3),
dtype="float32")
+ ) = R.meshgrid((input1, input2), indexing="xy")
+ lv1: R.Tensor((3, 3), dtype="float32") = lv[0]
+ lv2: R.Tensor((3, 3), dtype="float32") = lv[1]
+ gv: R.Tuple(
+ R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3),
dtype="float32")
+ ) = (lv1, lv2)
+ R.output(gv)
+ return gv
+
+ example_args = (
+ torch.randn(3, dtype=torch.float32),
+ torch.randn(3, dtype=torch.float32),
+ )
+ verify_model(Meshgrid1(), example_args, {}, expected1)
+ verify_model(Meshgrid2(), example_args, {}, expected2)
+
+
def test_permute():
class Permute1(Module):
def forward(self, x):
diff --git a/tests/python/relax/test_frontend_from_fx.py
b/tests/python/relax/test_frontend_from_fx.py
index 490a2309aa..2bb2a84441 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -3528,6 +3528,66 @@ def test_datatype():
verify_model(AsType(), input_info, {}, expected1)
+def test_meshgrid():
+ input_infos = [
+ (
+ [
+ 3,
+ ],
+ "float32",
+ ),
+ (
+ [
+ 3,
+ ],
+ "float32",
+ ),
+ ]
+
+ class Meshgrid1(Module):
+ def forward(self, input1, input2):
+ return torch.meshgrid((input1, input2), indexing="ij")
+
+ class Meshgrid2(Module):
+ def forward(self, input1, input2):
+ return torch.meshgrid((input1, input2), indexing="xy")
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ inp_0: R.Tensor((3,), dtype="float32"), inp_1: R.Tensor((3,),
dtype="float32")
+ ) -> R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3),
dtype="float32")):
+ with R.dataflow():
+ lv: R.Tuple(
+ R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3),
dtype="float32")
+ ) = R.meshgrid((inp_0, inp_1), indexing="ij")
+ gv: R.Tuple(
+ R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3),
dtype="float32")
+ ) = lv
+ R.output(gv)
+ return gv
+
+ @tvm.script.ir_module
+ class expected2:
+ @R.function
+ def main(
+ inp_0: R.Tensor((3,), dtype="float32"), inp_1: R.Tensor((3,),
dtype="float32")
+ ) -> R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3),
dtype="float32")):
+ with R.dataflow():
+ lv: R.Tuple(
+ R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3),
dtype="float32")
+ ) = R.meshgrid((inp_0, inp_1), indexing="xy")
+ gv: R.Tuple(
+ R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3),
dtype="float32")
+ ) = lv
+ R.output(gv)
+ return gv
+
+ verify_model(Meshgrid1(), input_infos, {}, expected1)
+ verify_model(Meshgrid2(), input_infos, {}, expected2)
+
+
def test_permute():
input_info = [([1, 2, 3, 4], "float32")]
diff --git a/tests/python/relax/test_op_manipulate.py
b/tests/python/relax/test_op_manipulate.py
index 28e762d9a4..2e171a0a5b 100644
--- a/tests/python/relax/test_op_manipulate.py
+++ b/tests/python/relax/test_op_manipulate.py
@@ -3502,6 +3502,51 @@ def test_scatter_nd_infer_struct_info():
)
+def test_meshgrid_infer_struct_info():
+ bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
+ t0 = relax.Var("t0", R.Tensor((3,), "float32"))
+ t1 = relax.Var("t1", R.Tensor((4,), "float32"))
+ t2 = relax.Var("t2", R.Tensor("float32", ndim=1))
+ t3 = relax.Var("t3", R.Tensor((5,), "float32", vdev0))
+
+ _check_inference(
+ bb,
+ relax.op.meshgrid((t0, t1), indexing="ij"),
+ relax.TupleStructInfo(
+ [relax.TensorStructInfo((3, 4), "float32"),
relax.TensorStructInfo((3, 4), "float32")]
+ ),
+ )
+
+ _check_inference(
+ bb,
+ relax.op.meshgrid((t3, t1), indexing="ij"),
+ relax.TupleStructInfo(
+ [
+ relax.TensorStructInfo((5, 4), "float32", vdev0),
+ relax.TensorStructInfo((5, 4), "float32", vdev0),
+ ]
+ ),
+ )
+
+ _check_inference(
+ bb,
+ relax.op.meshgrid((t2, t1), indexing="xy"),
+ relax.TupleStructInfo(
+ [
+ relax.TensorStructInfo(dtype="float32", ndim=2),
+ relax.TensorStructInfo(dtype="float32", ndim=2),
+ ]
+ ),
+ )
+
+ _check_inference(
+ bb,
+ relax.op.meshgrid((t0,), indexing="ij"),
+ relax.TupleStructInfo([relax.TensorStructInfo((3,), "float32")]),
+ )
+
+
def test_one_hot_infer_struct_info():
bb = relax.BlockBuilder()