This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new f5e8a3acb3 [Relax][PyTorch] Add logical_or and logical_xor converters (#19756)
f5e8a3acb3 is described below

commit f5e8a3acb3f157f8a1faf689fe0ece98a67c450c
Author: Javier De Jesus <[email protected]>
AuthorDate: Sat Jun 13 20:42:24 2026 +0200

    [Relax][PyTorch] Add logical_or and logical_xor converters (#19756)
    
    ### Motivation
    
    `torch.logical_or` and `torch.logical_xor` accept input tensors of any
    dtype (treating any nonzero element as `True`) and always return a
    `bool` tensor.
    
    Neither op was handled by the PyTorch frontend. The ExportedProgram
    frontend did not register `logical_or.default` / `logical_xor.default`,
    and the FX frontend did not register `logical_or` / `logical_xor`, so
    importing a model that uses either op failed early with
    `Unsupported function types`.
    
    This follows up on #19679 (`logical_and`) and addresses the explicit
    question raised in #19743: whether `logical_or` and `logical_xor` need
    the same handling.
    
    ### Changes
    
    - Add shared `_logical_or` and `_logical_xor` converters in
      `BaseFXGraphImporter` that cast non-bool operands to `bool` before
      applying `relax.op.logical_or` / `relax.op.logical_xor`. Bool operands
      are passed through unchanged (no redundant cast).
    - Register `logical_or.default` / `logical_xor.default` (ExportedProgram)
      and `logical_or` / `logical_xor` (FX), matching the existing
      `logical_and` converter.
    - Add standalone `test_logical_or` and `test_logical_xor` to both the FX
      and ExportedProgram test suites, asserting the expected IR (`astype`
      to bool on each operand, then the logical op, producing a `bool`
      output).
    
    ### Notes
    
    The cast to `bool` lowers to an elementwise nonzero test, so it matches
    PyTorch's "nonzero is True" semantics for float, integer, and NaN
    inputs.
---
 .../frontend/torch/base_fx_graph_translator.py     | 22 +++++++++
 .../frontend/torch/exported_program_translator.py  |  2 +
 python/tvm/relax/frontend/torch/fx_translator.py   |  2 +
 .../relax/test_frontend_from_exported_program.py   | 56 ++++++++++++++++++++++
 tests/python/relax/test_frontend_from_fx.py        | 50 +++++++++++++++++++
 5 files changed, 132 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 1bf36069d1..4c3cdd464f 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -410,6 +410,28 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
             x = self.block_builder.emit(relax.op.astype(x, "bool"))
         return self.block_builder.emit(relax.op.logical_not(x))
 
+    def _logical_or(self, node: fx.Node) -> relax.Var:
+        lhs = self.env[node.args[0]]
+        rhs = self.env[node.args[1]]
+        # torch.logical_or accepts any dtype (treating nonzero as True) and returns bool, but
+        # relax.op.logical_or requires boolean inputs, so cast non-bool inputs to bool first.
+        if lhs.struct_info.dtype != "bool":
+            lhs = self.block_builder.emit(relax.op.astype(lhs, "bool"))
+        if rhs.struct_info.dtype != "bool":
+            rhs = self.block_builder.emit(relax.op.astype(rhs, "bool"))
+        return self.block_builder.emit(relax.op.logical_or(lhs, rhs))
+
+    def _logical_xor(self, node: fx.Node) -> relax.Var:
+        lhs = self.env[node.args[0]]
+        rhs = self.env[node.args[1]]
+        # torch.logical_xor accepts any dtype (treating nonzero as True) and returns bool, but
+        # relax.op.logical_xor requires boolean inputs, so cast non-bool inputs to bool first.
+        if lhs.struct_info.dtype != "bool":
+            lhs = self.block_builder.emit(relax.op.astype(lhs, "bool"))
+        if rhs.struct_info.dtype != "bool":
+            rhs = self.block_builder.emit(relax.op.astype(rhs, "bool"))
+        return self.block_builder.emit(relax.op.logical_xor(lhs, rhs))
+
     def _prelu(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         alpha = self.env[node.args[1]]
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 3edbc2adb4..6c9e3e3f5e 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -1558,6 +1558,8 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "log1p.default": self._log1p,
             "logical_not.default": self._logical_not,
             "logical_and.default": self._logical_and,
+            "logical_or.default": self._logical_or,
+            "logical_xor.default": self._logical_xor,
             "log_softmax.int": self._log_softmax,
             "_log_softmax.default": self._log_softmax,
             "neg.default": self._unary_op(relax.op.negative),
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 4af86068d7..66d17a5828 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -877,6 +877,8 @@ class TorchFXImporter(BaseFXGraphImporter):
             "log1p": self._log1p,
             "logical_and": self._logical_and,
             "logical_not": self._logical_not,
+            "logical_or": self._logical_or,
+            "logical_xor": self._logical_xor,
             "log_softmax": self._log_softmax,
             "neg": self._unary_op(relax.op.negative),
             "pad": self._pad,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 958e7ce054..b4f1c475d9 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -1066,6 +1066,62 @@ def test_logical_not():
     verify_model(LogicalNot(), example_args, {}, expected)
 
 
+def test_logical_or():
+    class LogicalOr(Module):
+        def forward(self, lhs, rhs):
+            return torch.logical_or(lhs, rhs)
+
+    @tvm.script.ir_module
+    class expected:
+        @R.function
+        def main(
+            lhs: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            rhs: R.Tensor((1, 3, 10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="bool")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.astype(lhs, dtype="bool")
+                lv1: R.Tensor((1, 3, 10, 10), dtype="bool") = R.astype(rhs, dtype="bool")
+                lv2: R.Tensor((1, 3, 10, 10), dtype="bool") = R.logical_or(lv, lv1)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="bool")) = (lv2,)
+                R.output(gv)
+            return gv
+
+    example_args = (
+        torch.randn(1, 3, 10, 10, dtype=torch.float32),
+        torch.randn(1, 3, 10, 10, dtype=torch.float32),
+    )
+    verify_model(LogicalOr(), example_args, {}, expected)
+
+
+def test_logical_xor():
+    class LogicalXor(Module):
+        def forward(self, lhs, rhs):
+            return torch.logical_xor(lhs, rhs)
+
+    @tvm.script.ir_module
+    class expected:
+        @R.function
+        def main(
+            lhs: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            rhs: R.Tensor((1, 3, 10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="bool")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.astype(lhs, dtype="bool")
+                lv1: R.Tensor((1, 3, 10, 10), dtype="bool") = R.astype(rhs, dtype="bool")
+                lv2: R.Tensor((1, 3, 10, 10), dtype="bool") = R.logical_xor(lv, lv1)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="bool")) = (lv2,)
+                R.output(gv)
+            return gv
+
+    example_args = (
+        torch.randn(1, 3, 10, 10, dtype=torch.float32),
+        torch.randn(1, 3, 10, 10, dtype=torch.float32),
+    )
+    verify_model(LogicalXor(), example_args, {}, expected)
+
+
 def test_pow_integer():
     class Pow(Module):
         def forward(self, input):
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index a5f799e6d6..cdb343e73a 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -3553,6 +3553,56 @@ def test_logical_and():
     verify_model(LogicalAnd(), input_info, {}, expected)
 
 
+def test_logical_or():
+    input_info = [([1, 3, 10, 10], "float32"), ([1, 3, 10, 10], "float32")]
+
+    class LogicalOr(Module):
+        def forward(self, lhs, rhs):
+            return torch.logical_or(lhs, rhs)
+
+    @tvm.script.ir_module
+    class expected:
+        @R.function
+        def main(
+            lhs: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            rhs: R.Tensor((1, 3, 10, 10), dtype="float32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.astype(lhs, dtype="bool")
+                lv1: R.Tensor((1, 3, 10, 10), dtype="bool") = R.astype(rhs, dtype="bool")
+                lv2: R.Tensor((1, 3, 10, 10), dtype="bool") = R.logical_or(lv, lv1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv2
+                R.output(gv)
+            return gv
+
+    verify_model(LogicalOr(), input_info, {}, expected)
+
+
+def test_logical_xor():
+    input_info = [([1, 3, 10, 10], "float32"), ([1, 3, 10, 10], "float32")]
+
+    class LogicalXor(Module):
+        def forward(self, lhs, rhs):
+            return torch.logical_xor(lhs, rhs)
+
+    @tvm.script.ir_module
+    class expected:
+        @R.function
+        def main(
+            lhs: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            rhs: R.Tensor((1, 3, 10, 10), dtype="float32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.astype(lhs, dtype="bool")
+                lv1: R.Tensor((1, 3, 10, 10), dtype="bool") = R.astype(rhs, dtype="bool")
+                lv2: R.Tensor((1, 3, 10, 10), dtype="bool") = R.logical_xor(lv, lv1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv2
+                R.output(gv)
+            return gv
+
+    verify_model(LogicalXor(), input_info, {}, expected)
+
+
 def test_pow_integer():
     input_info = [([4], "int64")]
 

Reply via email to