This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 9898909392 [Relax][PyTorch] Cast non-bool inputs to bool in logical_not converter (#19645)
9898909392 is described below

commit 9898909392bcdf9b155da49a5173b6efaf6913f6
Author: Javier De Jesus <[email protected]>
AuthorDate: Mon Jun 1 21:14:15 2026 +0200

    [Relax][PyTorch] Cast non-bool inputs to bool in logical_not converter (#19645)
    
    ### Motivation
    
    `torch.logical_not` accepts an input tensor of any dtype (treating any
    nonzero element as `True`) and always returns a `bool` tensor.
    
    The PyTorch frontend previously lowered it with
    `self._unary_op(relax.op.logical_not)`. `relax.op.logical_not` is a
    unary arithmetic op that passes its input dtype through, so a non-bool
    input (for example `float32`) produced a `float32` result instead of
    the `bool` result PyTorch returns. This dtype mismatch against the
    reference PyTorch semantics affected both the FX and ExportedProgram
    frontends.
    
    ### Changes
    
    - Add a shared `_logical_not` converter in `BaseFXGraphImporter` that
      casts non-bool inputs to `bool` before applying `relax.op.logical_not`.
      Bool inputs are passed through unchanged (no redundant cast).
    - Point the `logical_not` (FX) and `logical_not.default`
      (ExportedProgram) registrations at the new converter.
    - Update the FX test and add a standalone ExportedProgram
      `test_logical_not` to assert the corrected IR (`astype` to bool, then
      `logical_not`, producing a `bool` output).
    
    ### Notes
    
    The cast to `bool` lowers to an elementwise nonzero test, so it matches
    PyTorch's "nonzero is True" semantics for float, integer, and NaN inputs.
---
 .../frontend/torch/base_fx_graph_translator.py     |  8 ++++++++
 .../frontend/torch/exported_program_translator.py  |  2 +-
 python/tvm/relax/frontend/torch/fx_translator.py   |  2 +-
 .../relax/test_frontend_from_exported_program.py   | 23 ++++++++++++++++++++++
 tests/python/relax/test_frontend_from_fx.py        |  7 ++++---
 5 files changed, 37 insertions(+), 5 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index e9bddc4500..a2ebed0480 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -389,6 +389,14 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 
-1)
         return self.block_builder.emit(relax.op.nn.log_softmax(x, dim))
 
+    def _logical_not(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        # torch.logical_not accepts any dtype (treating nonzero as True) and returns bool, but
+        # relax.op.logical_not keeps its input dtype, so cast non-bool inputs to bool first.
+        if x.struct_info.dtype != "bool":
+            x = self.block_builder.emit(relax.op.astype(x, "bool"))
+        return self.block_builder.emit(relax.op.logical_not(x))
+
     def _prelu(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         alpha = self.env[node.args[1]]
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 596dc60f55..26f5a5918c 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -1551,7 +1551,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "log2.default": self._log2,
             "log10.default": self._log10,
             "log1p.default": self._log1p,
-            "logical_not.default": self._unary_op(relax.op.logical_not),
+            "logical_not.default": self._logical_not,
             "logical_and.default": self._binary_op(relax.op.logical_and, 
operator.and_),
             "log_softmax.int": self._log_softmax,
             "_log_softmax.default": self._log_softmax,
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index d4dd6902ae..9d27f62b42 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -875,7 +875,7 @@ class TorchFXImporter(BaseFXGraphImporter):
             "log2": self._log2,
             "log10": self._log10,
             "log1p": self._log1p,
-            "logical_not": self._unary_op(relax.op.logical_not),
+            "logical_not": self._logical_not,
             "log_softmax": self._log_softmax,
             "neg": self._unary_op(relax.op.negative),
             "pad": self._pad,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 6b758c1ba7..d1bdad7578 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -1062,6 +1062,29 @@ def test_logaddexp():
     verify_model(LogAddExp(), example_args, {}, expected)
 
 
+def test_logical_not():
+    class LogicalNot(Module):
+        def forward(self, input):
+            return torch.logical_not(input)
+
+    @tvm.script.ir_module
+    class expected:
+        @R.function
+        def main(input: R.Tensor((1, 3, 10, 10), dtype="float32")) -> R.Tuple(
+            R.Tensor((1, 3, 10, 10), dtype="bool")
+        ):
+            # float32 input: cast to bool first, then logical_not yields a bool result
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.astype(input, 
dtype="bool")
+                lv1: R.Tensor((1, 3, 10, 10), dtype="bool") = R.logical_not(lv)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="bool")) = (lv1,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+    verify_model(LogicalNot(), example_args, {}, expected)
+
 def test_logsoftmax():
     class LogSoftmax(Module):
         def __init__(self):
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 410875985e..1bf71fb6eb 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -3195,11 +3195,12 @@ def test_extended_unary_ops():
     class expected_logical_not:
         @R.function
         def main(inp_0: R.Tensor((1, 3, 10, 10), dtype="float32")) -> R.Tensor(
-            (1, 3, 10, 10), dtype="float32"
+            (1, 3, 10, 10), dtype="bool"
         ):
             with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.logical_not(inp_0)
-                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.astype(inp_0, 
dtype="bool")
+                lv1: R.Tensor((1, 3, 10, 10), dtype="bool") = R.logical_not(lv)
+                gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv1
                 R.output(gv)
             return gv
 

Reply via email to