This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 2de5b1e08c [Relax][PyTorch] Add torch.isin Op Support for Exported Program and FX graph (#17878)
2de5b1e08c is described below

commit 2de5b1e08c8e8541412a701f3714a6cc0dd69d10
Author: Deivanayaki S <[email protected]>
AuthorDate: Sat Apr 26 09:25:23 2025 +0530

    [Relax][PyTorch] Add torch.isin Op Support for Exported Program and FX graph (#17878)
    
    * add torch.isin op support into torch frontends
    
    * fix lint issues in test script
    
    ---------
    
    Co-authored-by: deivanayakisankaralingam <deiva@Deivanayaki>
---
 .../frontend/torch/base_fx_graph_translator.py     | 14 ++++++++++
 .../frontend/torch/exported_program_translator.py  |  1 +
 python/tvm/relax/frontend/torch/fx_translator.py   |  1 +
 .../relax/test_frontend_from_exported_program.py   | 31 ++++++++++++++++++++++
 tests/python/relax/test_frontend_from_fx.py        | 29 ++++++++++++++++++++
 5 files changed, 76 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 3e81ff1f0b..c1a1a61398 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -417,6 +417,20 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
 
         return self.block_builder.emit(relax.op.subtract(rhs, lhs))
 
+    def _isin(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        elements = args[0]
+        test_elements = args[1]
+
+        expanded_elements = relax.op.expand_dims(elements, axis=-1)
+        flattened_test_elements = relax.op.reshape(test_elements, (-1,))
+
+        comparison = relax.op.equal(expanded_elements, flattened_test_elements)
+        summed = relax.op.sum(comparison, axis=-1)
+        result = relax.op.greater(summed, relax.const(0, dtype=elements.struct_info.dtype))
+
+        return self.block_builder.emit(result)
+
     ########## Neural Network ##########
 
     def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var:
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index a3ab575c4b..88f6dd538d 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -299,6 +299,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "hardtanh_.default": self._hardtanh,
             "isfinite.default": self._unary_op(relax.op.isfinite),
             "isinf.default": self._unary_op(relax.op.isinf),
+            "isin.Tensor_Tensor": self._isin,
             "isnan.default": self._unary_op(relax.op.isnan),
             "leaky_relu.default": self._leakyrelu,
             "leaky_relu_.default": self._leakyrelu,
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 18dba2d988..0d3dafc8d5 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -693,6 +693,7 @@ class TorchFXImporter(BaseFXGraphImporter):
             "hardtanh": self._hardtanh,
             "isfinite": self._unary_op(relax.op.isfinite),
             "isinf": self._unary_op(relax.op.isinf),
+            "isin": self._isin,
             "isnan": self._unary_op(relax.op.isnan),
             "leaky_relu": self._leakyrelu,
             "log": self._unary_op(relax.op.log),
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index e3b6f4ad9c..8cc3dde397 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -1060,6 +1060,37 @@ def test_binary3():
     verify_model(RSub2(), example_args2, {}, expected_rsub2)
 
 
+# IsIn
+
+
+def test_isin():
+    class IsInModel(torch.nn.Module):
+        def forward(self, x, test_elements):
+            return torch.isin(x, test_elements)
+
+    @tvm.script.ir_module
+    class expected:
+        @R.function
+        def main(
+            x: R.Tensor((10, 10), dtype="float32"), test_elements: R.Tensor((8,), dtype="float32")
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="bool")):
+            with R.dataflow():
+                lv: R.Tensor((10, 10, 1), dtype="float32") = R.expand_dims(x, axis=[-1])
+                lv1: R.Tensor((8,), dtype="float32") = R.reshape(test_elements, R.shape([8]))
+                lv2: R.Tensor((10, 10, 8), dtype="bool") = R.equal(lv, lv1)
+                lv3: R.Tensor((10, 10), dtype="bool") = R.sum(lv2, axis=[-1], keepdims=False)
+                lv4: R.Tensor((10, 10), dtype="bool") = R.greater(lv3, R.const(0.0, "float32"))
+                gv: R.Tuple(R.Tensor((10, 10), dtype="bool")) = (lv4,)
+                R.output(gv)
+            return gv
+
+    example_args = (
+        torch.randn(10, 10, dtype=torch.float32),
+        torch.randn(8, dtype=torch.float32),
+    )
+    verify_model(IsInModel(), example_args, {}, expected)
+
+
 def test_batchnorm2d():
     class BatchNorm2d(Module):
         def __init__(self):
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 4003202d4f..48c2cec8c0 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -1868,6 +1868,35 @@ def test_rsub():
     verify_model(RSub2(), input_info2, {}, expected_rsub2)
 
 
+# IsIn
+
+
+def test_isin():
+    input_info = [([10, 10], "float32"), ([8], "float32")]
+
+    class IsInModel(torch.nn.Module):
+        def forward(self, x, test_elements):
+            return torch.isin(x, test_elements)
+
+    @tvm.script.ir_module
+    class expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((10, 10), dtype="float32"), inp_1: R.Tensor((8,), dtype="float32")
+        ) -> R.Tensor((10, 10), dtype="bool"):
+            with R.dataflow():
+                lv: R.Tensor((10, 10, 1), dtype="float32") = R.expand_dims(inp_0, axis=[-1])
+                lv1: R.Tensor((8,), dtype="float32") = R.reshape(inp_1, R.shape([8]))
+                lv2: R.Tensor((10, 10, 8), dtype="bool") = R.equal(lv, lv1)
+                lv3: R.Tensor((10, 10), dtype="bool") = R.sum(lv2, axis=[-1], keepdims=False)
+                lv4: R.Tensor((10, 10), dtype="bool") = R.greater(lv3, R.const(0.0, "float32"))
+                gv: R.Tensor((10, 10), dtype="bool") = lv4
+                R.output(gv)
+            return gv
+
+    verify_model(IsInModel(), input_info, {}, expected)
+
+
 def test_size():
     input_info = [([1, 3, 10, 10], "float32")]
 

Reply via email to