This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 26b107fa12 [Relax][PyTorch] Add support for masked_select (#18535)
26b107fa12 is described below
commit 26b107fa12672c3b958da222fc87755a69d64c42
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Mon Dec 8 03:59:25 2025 +0800
[Relax][PyTorch] Add support for masked_select (#18535)
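## How
Add support for `torch.masked_select` to the Relax PyTorch frontend:
- Lower `masked_select.default` in four steps:
  - broadcast the mask to the data's shape when the two shapes differ;
  - flatten both tensors;
  - find the indices of the true mask entries with `nonzero` + `squeeze`;
  - gather those elements with `take`.
  Only the mask is broadcast; the data is never broadcast to the mask's shape.
- Register the symbolic-shape constraint ops `sym_constrain_range_for_size`, `_assert_scalar`, `ge` and `le` as no-ops in the exported-program importer.
- Export `nonzero` from the Relax script IR builder.
- Add `test_masked_select` to the exported-program frontend tests.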
---
.../frontend/torch/base_fx_graph_translator.py | 21 ++++++++++++
.../frontend/torch/exported_program_translator.py | 11 +++++++
python/tvm/script/ir_builder/relax/ir.py | 2 ++
.../relax/test_frontend_from_exported_program.py | 37 ++++++++++++++++++++++
4 files changed, 71 insertions(+)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 7ebb95c136..471d4209d7 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -23,6 +23,7 @@ from functools import reduce
import math
from typing import Callable, Dict, Optional, Tuple, Union, List
+import tvm
from tvm import relax, tir
@@ -2385,6 +2386,26 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
return self.block_builder.emit(relax.op.where(mask, values, x))
+ def _masked_select(self, node: fx.Node) -> relax.Var:
+ data = self.env[node.args[0]]
+ mask = self.env[node.args[1]]
+
+ data_shape = self.shape_of(data)
+ mask_shape = self.shape_of(mask)
+ shapes_equal = tvm.ir.structural_equal(data_shape, mask_shape)
+
+ if not shapes_equal:
+ mask = self.block_builder.emit(relax.op.broadcast_to(mask,
data_shape))
+
+ data_flat = self.block_builder.emit(relax.op.reshape(data, [-1]))
+ mask_flat = self.block_builder.emit(relax.op.reshape(mask, [-1]))
+ indices = self.block_builder.emit(relax.op.nonzero(mask_flat))
+ indices_1d = self.block_builder.emit(relax.op.squeeze(indices,
axis=[0]))
+
+ result = self.block_builder.emit(relax.op.take(data_flat, indices_1d,
axis=0))
+
+ return result
+
def _new_ones(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
self_var = args[0]
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 641e16f599..3e2274e551 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -1153,6 +1153,11 @@ class ExportedProgramImporter(BaseFXGraphImporter):
return self.block_builder.emit(relax.op.reshape(x, size))
+ ########## Symbolic Shape Constraints ##########
+
+ def _symbolic_comparison(self, _: fx.Node) -> relax.Expr:
+ return self.block_builder.emit(relax.const(True, dtype="bool"))
+
########## Others ##########
def create_convert_map(
@@ -1457,6 +1462,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"linspace.default": self._linspace,
"masked_fill.Scalar": self._masked_fill,
"masked_fill_.Scalar": self._inplace_masked_fill,
+ "masked_select.default": self._masked_select,
"new_ones.default": self._new_ones,
"new_zeros.default": self._new_zeros,
"one_hot.default": self._one_hot,
@@ -1477,6 +1483,11 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"item.default": self._item,
"sym_size.int": self._sym_size_int,
"_local_scalar_dense.default": self._item,
+ # symbolic shape constraints (no-ops for compilation)
+ "sym_constrain_range_for_size.default": lambda node:
self.env[node.args[0]],
+ "_assert_scalar.default": lambda node: self.env[node.args[0]],
+ "ge": self._symbolic_comparison,
+ "le": self._symbolic_comparison,
}
def _process_derived_symbol(
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index f221a13089..141361a729 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -137,6 +137,7 @@ from tvm.relax.op import (
multiply,
negative,
nn,
+ nonzero,
not_equal,
null_value,
ones,
@@ -882,6 +883,7 @@ __all__ = [
"multinomial_from_uniform",
"multiply",
"negative",
+ "nonzero",
"not_equal",
"null_value",
"ones",
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 68567e1fc8..74ad2329fe 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -6231,6 +6231,43 @@ def test_masked_fill_inplace():
verify_model(Masked_Fill_Inplace(), example_args, {}, Expected)
+def test_masked_select():
+ class MaskedSelect(Module):
+ def forward(self, data: torch.Tensor, mask: torch.Tensor):
+ return torch.masked_select(data, mask)
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(
+ data: R.Tensor((2, 3), dtype="float32"), mask: R.Tensor((2, 3),
dtype="bool")
+ ) -> R.Tuple(R.Tensor(dtype="float32", ndim=1)):
+ R.func_attr(
+ {
+ "tir_var_lower_bound": {"u0": 0, "u1": 0},
+ "tir_var_upper_bound": {"u0": 6, "u1": 6},
+ }
+ )
+ with R.dataflow():
+ lv: R.Tensor((6,), dtype="float32") = R.reshape(data,
R.shape([6]))
+ lv1: R.Tensor((6,), dtype="bool") = R.reshape(mask,
R.shape([6]))
+ lv2: R.Tensor(dtype="int64", ndim=2) = R.nonzero(lv1)
+ lv3: R.Tensor(dtype="int64", ndim=1) = R.squeeze(lv2, axis=[0])
+ lv4: R.Tensor(dtype="float32", ndim=1) = R.take(lv, lv3,
axis=0, mode="fast")
+ lv5: R.Tensor((), dtype="int64") = R.const(0, "int64")
+ lv6: R.Tensor((), dtype="bool") = R.const(True, "bool")
+ lv7: R.Tensor((), dtype="bool") = R.const(True, "bool")
+ gv: R.Tuple(R.Tensor(dtype="float32", ndim=1)) = (lv4,)
+ R.output(gv)
+ return gv
+
+ example_args = (
+ torch.randn(2, 3, dtype=torch.float32),
+ torch.tensor([[True, False, True], [False, True, False]]),
+ )
+ verify_model(MaskedSelect(), example_args, {}, Expected)
+
+
def test_new_ones():
class NewOnes(Module):
def forward(self, x):