This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4244a8658a [Relax][PyTorch] Add boolean tensor support for max operation and corresponding test case (#18530)
4244a8658a is described below

commit 4244a8658ae3f4d9f4a77038a5cc0e5514a63080
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Sun Nov 30 02:56:58 2025 +0900

    [Relax][PyTorch] Add boolean tensor support for max operation and corresponding test case (#18530)
    
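    Relax `max` does not support boolean tensors directly, so the PyTorch
    `any` reduction now casts boolean inputs to int8, reduces with `max`,
    and casts the result back to bool. Adds a `test_any` case and updates
    the expected IR in `test_isin` to reflect the extra casts. Also updates
    the `test_scaled_dot_product_attention` expectations to the
    `R.nn.attention` / `R.nn.attention_bias` form and runs those cases with
    `run_ep_decomposition=False`.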
    ref: https://github.com/apache/tvm/pull/18524#discussion_r2573023355
---
 .../frontend/torch/base_fx_graph_translator.py     |   6 +
 .../relax/test_frontend_from_exported_program.py   | 176 ++++++---------------
 2 files changed, 57 insertions(+), 125 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py 
b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index e554648c41..33a22b34fc 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1634,6 +1634,12 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
         keepdim = args[2] if len(node.args) > 2 else 
node.kwargs.get("keepdim", False)
 
+        # max doesn't support boolean tensors directly, so we compute it in 
int8 and cast back
+        if x.struct_info.dtype == "bool":
+            x = relax.op.astype(x, "int8")
+            ret = relax.op.max(x, dim, keepdims=keepdim)
+            return self.block_builder.emit(relax.op.astype(ret, "bool"))
+
         # For boolean tensors, any is equivalent to max (checking if any 
element is True)
         return self.block_builder.emit(relax.op.max(x, dim, keepdims=keepdim))
 
diff --git a/tests/python/relax/test_frontend_from_exported_program.py 
b/tests/python/relax/test_frontend_from_exported_program.py
index 662df5e76a..7397b3f21a 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -1693,8 +1693,10 @@ def test_isin():
             with R.dataflow():
                 lv: R.Tensor((10, 10, 1), dtype="float32") = R.reshape(x, 
R.shape([10, 10, 1]))
                 lv1: R.Tensor((10, 10, 8), dtype="bool") = R.equal(lv, 
test_elements)
-                lv2: R.Tensor((10, 10), dtype="bool") = R.max(lv1, axis=[-1], 
keepdims=False)
-                gv: R.Tuple(R.Tensor((10, 10), dtype="bool")) = (lv2,)
+                lv2: R.Tensor((10, 10, 8), dtype="int8") = R.astype(lv1, 
dtype="int8")
+                lv3: R.Tensor((10, 10), dtype="int8") = R.max(lv2, axis=[-1], 
keepdims=False)
+                lv4: R.Tensor((10, 10), dtype="bool") = R.astype(lv3, 
dtype="bool")
+                gv: R.Tuple(R.Tensor((10, 10), dtype="bool")) = (lv4,)
                 R.output(gv)
             return gv
 
@@ -4118,71 +4120,22 @@ def test_scaled_dot_product_attention():
             v: R.Tensor((32, 8, 128, 64), dtype="float32"),
         ) -> R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((32, 8, 128, 64), dtype="float32") = R.multiply(
-                    q, R.const(0.35355338454246521, "float32")
+                lv: R.Tensor((32, 128, 8, 64), dtype="float32") = 
R.permute_dims(
+                    q, axes=[0, 2, 1, 3]
                 )
-                lv1: R.Tensor((32, 8, 64, 128), dtype="float32") = 
R.permute_dims(
-                    k, axes=[0, 1, 3, 2]
+                lv1: R.Tensor((32, 128, 8, 64), dtype="float32") = 
R.permute_dims(
+                    k, axes=[0, 2, 1, 3]
                 )
-                lv2: R.Tensor((32, 8, 64, 128), dtype="float32") = R.multiply(
-                    lv1, R.const(0.35355338454246521, "float32")
+                lv2: R.Tensor((32, 128, 8, 64), dtype="float32") = 
R.permute_dims(
+                    v, axes=[0, 2, 1, 3]
                 )
-                lv3: R.Tensor((32, 8, 128, 64), dtype="float32") = 
R.broadcast_to(
-                    lv, R.shape([32, 8, 128, 64])
+                lv3: R.Tensor((32, 128, 8, 64), dtype="float32") = 
R.nn.attention(
+                    lv, lv1, lv2, scale=None, causal_mask=None, 
window_size=None
                 )
-                lv4: R.Tensor((256, 128, 64), dtype="float32") = R.reshape(
-                    lv3, R.shape([256, 128, 64])
+                lv4: R.Tensor((32, 8, 128, 64), dtype="float32") = 
R.permute_dims(
+                    lv3, axes=[0, 2, 1, 3]
                 )
-                lv5: R.Tensor((32, 8, 64, 128), dtype="float32") = 
R.broadcast_to(
-                    lv2, R.shape([32, 8, 64, 128])
-                )
-                lv6: R.Tensor((256, 64, 128), dtype="float32") = R.reshape(
-                    lv5, R.shape([256, 64, 128])
-                )
-                lv7: R.Tensor((256, 128, 128), dtype="float32") = R.matmul(
-                    lv4, lv6, out_dtype="float32"
-                )
-                lv8: R.Tensor((32, 8, 128, 128), dtype="float32") = R.reshape(
-                    lv7, R.shape([32, 8, 128, 128])
-                )
-                lv9: R.Tensor((32, 8, 128, 128), dtype="float32") = 
R.nn.softmax(lv8, axis=-1)
-                lv10: R.Tensor((32, 8, 128, 128), dtype="bool") = R.equal(
-                    lv8, R.const(float("-inf"), "float32")
-                )
-                lv11: R.Tensor((32, 8, 128, 128), dtype="bool") = 
R.logical_not(lv10)
-                lv12: R.Tensor((32, 8, 128, 1), dtype="bool") = R.max(
-                    lv11, axis=[-1], keepdims=True
-                )
-                lv13: R.Tensor((32, 8, 128, 1), dtype="bool") = 
R.logical_not(lv12)
-                lv14: R.Tensor((32, 8, 128, 128), dtype="float32") = 
R.full_like(
-                    lv9, R.const(0, "int32"), dtype="void"
-                )
-                lv15: R.Tensor((32, 8, 128, 128), dtype="float32") = 
R.where(lv13, lv14, lv9)
-                lv16: R.Tensor((32, 8, 128, 128), dtype="float32") = 
R.broadcast_to(
-                    lv15, R.shape([32, 8, 128, 128])
-                )
-                lv17: R.Tensor((256, 128, 128), dtype="float32") = R.reshape(
-                    lv16, R.shape([256, 128, 128])
-                )
-                lv18: R.Tensor((32, 8, 128, 64), dtype="float32") = 
R.broadcast_to(
-                    v, R.shape([32, 8, 128, 64])
-                )
-                lv19: R.Tensor((256, 128, 64), dtype="float32") = R.reshape(
-                    lv18, R.shape([256, 128, 64])
-                )
-                lv20: R.Tensor((256, 128, 64), dtype="float32") = R.matmul(
-                    lv17, lv19, out_dtype="float32"
-                )
-                lv21: R.Tensor((32, 8, 128, 64), dtype="float32") = R.reshape(
-                    lv20, R.shape([32, 8, 128, 64])
-                )
-                lv22: R.Tensor((128, 32, 8, 64), dtype="float32") = 
R.permute_dims(
-                    lv21, axes=[2, 0, 1, 3]
-                )
-                lv23: R.Tensor((32, 8, 128, 64), dtype="float32") = 
R.permute_dims(
-                    lv22, axes=[1, 2, 0, 3]
-                )
-                gv: R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")) = 
(lv23,)
+                gv: R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")) = 
(lv4,)
                 R.output(gv)
             return gv
 
@@ -4200,72 +4153,22 @@ def test_scaled_dot_product_attention():
             mask: R.Tensor((32, 8, 128, 128), dtype="float32"),
         ) -> R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((32, 8, 128, 64), dtype="float32") = R.multiply(
-                    q, R.const(0.35355338454246521, "float32")
-                )
-                lv1: R.Tensor((32, 8, 64, 128), dtype="float32") = 
R.permute_dims(
-                    k, axes=[0, 1, 3, 2]
-                )
-                lv2: R.Tensor((32, 8, 64, 128), dtype="float32") = R.multiply(
-                    lv1, R.const(0.35355338454246521, "float32")
-                )
-                lv3: R.Tensor((32, 8, 128, 64), dtype="float32") = 
R.broadcast_to(
-                    lv, R.shape([32, 8, 128, 64])
-                )
-                lv4: R.Tensor((256, 128, 64), dtype="float32") = R.reshape(
-                    lv3, R.shape([256, 128, 64])
-                )
-                lv5: R.Tensor((32, 8, 64, 128), dtype="float32") = 
R.broadcast_to(
-                    lv2, R.shape([32, 8, 64, 128])
-                )
-                lv6: R.Tensor((256, 64, 128), dtype="float32") = R.reshape(
-                    lv5, R.shape([256, 64, 128])
-                )
-                lv7: R.Tensor((256, 128, 128), dtype="float32") = R.matmul(
-                    lv4, lv6, out_dtype="float32"
-                )
-                lv8: R.Tensor((32, 8, 128, 128), dtype="float32") = R.reshape(
-                    lv7, R.shape([32, 8, 128, 128])
+                lv: R.Tensor((32, 128, 8, 64), dtype="float32") = 
R.permute_dims(
+                    q, axes=[0, 2, 1, 3]
                 )
-                lv9: R.Tensor((32, 8, 128, 128), dtype="float32") = R.add(lv8, 
mask)
-                lv10: R.Tensor((32, 8, 128, 128), dtype="float32") = 
R.nn.softmax(lv9, axis=-1)
-                lv11: R.Tensor((32, 8, 128, 128), dtype="bool") = R.equal(
-                    lv9, R.const(float("-inf"), "float32")
+                lv1: R.Tensor((32, 128, 8, 64), dtype="float32") = 
R.permute_dims(
+                    k, axes=[0, 2, 1, 3]
                 )
-                lv12: R.Tensor((32, 8, 128, 128), dtype="bool") = 
R.logical_not(lv11)
-                lv13: R.Tensor((32, 8, 128, 1), dtype="bool") = R.max(
-                    lv12, axis=[-1], keepdims=True
+                lv2: R.Tensor((32, 128, 8, 64), dtype="float32") = 
R.permute_dims(
+                    v, axes=[0, 2, 1, 3]
                 )
-                lv14: R.Tensor((32, 8, 128, 1), dtype="bool") = 
R.logical_not(lv13)
-                lv15: R.Tensor((32, 8, 128, 128), dtype="float32") = 
R.full_like(
-                    lv10, R.const(0, "int32"), dtype="void"
+                lv3: R.Tensor((32, 128, 8, 64), dtype="float32") = 
R.nn.attention_bias(
+                    lv, lv1, lv2, mask, scale=None, causal_mask=None, 
window_size=None
                 )
-                lv16: R.Tensor((32, 8, 128, 128), dtype="float32") = 
R.where(lv14, lv15, lv10)
-                lv17: R.Tensor((32, 8, 128, 128), dtype="float32") = 
R.broadcast_to(
-                    lv16, R.shape([32, 8, 128, 128])
+                lv4: R.Tensor((32, 8, 128, 64), dtype="float32") = 
R.permute_dims(
+                    lv3, axes=[0, 2, 1, 3]
                 )
-                lv18: R.Tensor((256, 128, 128), dtype="float32") = R.reshape(
-                    lv17, R.shape([256, 128, 128])
-                )
-                lv19: R.Tensor((32, 8, 128, 64), dtype="float32") = 
R.broadcast_to(
-                    v, R.shape([32, 8, 128, 64])
-                )
-                lv20: R.Tensor((256, 128, 64), dtype="float32") = R.reshape(
-                    lv19, R.shape([256, 128, 64])
-                )
-                lv21: R.Tensor((256, 128, 64), dtype="float32") = R.matmul(
-                    lv18, lv20, out_dtype="float32"
-                )
-                lv22: R.Tensor((32, 8, 128, 64), dtype="float32") = R.reshape(
-                    lv21, R.shape([32, 8, 128, 64])
-                )
-                lv23: R.Tensor((128, 32, 8, 64), dtype="float32") = 
R.permute_dims(
-                    lv22, axes=[2, 0, 1, 3]
-                )
-                lv24: R.Tensor((32, 8, 128, 64), dtype="float32") = 
R.permute_dims(
-                    lv23, axes=[1, 2, 0, 3]
-                )
-                gv: R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")) = 
(lv24,)
+                gv: R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")) = 
(lv4,)
                 R.output(gv)
             return gv
 
@@ -4278,7 +4181,7 @@ def test_scaled_dot_product_attention():
         ),
         {},
         Expected1,
-        run_ep_decomposition=True,
+        run_ep_decomposition=False,
     )
 
     verify_model(
@@ -4291,7 +4194,7 @@ def test_scaled_dot_product_attention():
         ),
         {},
         Expected2,
-        run_ep_decomposition=True,
+        run_ep_decomposition=False,
     )
 
     # Test 2D input (seq_len, head_dim) - bug fix for #18441
@@ -7307,6 +7210,29 @@ def test_take():
     verify_model(Take(), example_args, {}, Expected)
 
 
+def test_any():
+    class AnyAten(torch.nn.Module):
+        def forward(self, x):
+            return torch.ops.aten.any(x, dim=1)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3), dtype="bool"),
+        ) -> R.Tuple(R.Tensor((2,), dtype="bool")):
+            with R.dataflow():
+                lv: R.Tensor((2, 3), dtype="int8") = relax.op.astype(x, 
dtype="int8")
+                lv2: R.Tensor((2,), dtype="int8") = relax.op.max(lv, axis=1, 
keepdims=False)
+                lv3: R.Tensor((2,), dtype="bool") = relax.op.astype(lv2, 
dtype="bool")
+                gv: R.Tuple(R.Tensor((2,), dtype="bool")) = (lv3,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.tensor([[0, 0, 0], [0, 1, 0]], dtype=torch.bool),)
+    verify_model(AnyAten(), example_args, {}, Expected)
+
+
 def test_std():
     class Std(Module):
         def forward(self, x):

Reply via email to