This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4244a8658a [Relax][PyTorch] Add boolean tensor support for max operation and corresponding test case (#18530)
4244a8658a is described below
commit 4244a8658ae3f4d9f4a77038a5cc0e5514a63080
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Sun Nov 30 02:56:58 2025 +0900
[Relax][PyTorch] Add boolean tensor support for max operation and corresponding test case (#18530)
As per title.
ref: https://github.com/apache/tvm/pull/18524#discussion_r2573023355
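
For context, a minimal end-to-end sketch (not part of this commit, assumes a TVM build that ships tvm.relax.frontend.torch.from_exported_program) showing how the new bool handling can be exercised; it mirrors the AnyAten module added in the test below:

    import torch
    import tvm
    from tvm.relax.frontend.torch import from_exported_program

    class AnyAten(torch.nn.Module):
        def forward(self, x):
            # aten.any lowers to a max reduction; for bool inputs the importer
            # now casts to int8, reduces, and casts back to bool.
            return torch.ops.aten.any(x, dim=1)

    example_args = (torch.tensor([[0, 0, 0], [0, 1, 0]], dtype=torch.bool),)
    exported = torch.export.export(AnyAten(), example_args)
    mod = from_exported_program(exported)
    mod.show()  # printed Relax should contain astype -> max -> astype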
---
.../frontend/torch/base_fx_graph_translator.py | 6 +
.../relax/test_frontend_from_exported_program.py | 176 ++++++---------------
2 files changed, 57 insertions(+), 125 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index e554648c41..33a22b34fc 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1634,6 +1634,12 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
+ # max doesn't support boolean tensors directly, so we compute it in int8 and cast back
+ if x.struct_info.dtype == "bool":
+ x = relax.op.astype(x, "int8")
+ ret = relax.op.max(x, dim, keepdims=keepdim)
+ return self.block_builder.emit(relax.op.astype(ret, "bool"))
+
# For boolean tensors, any is equivalent to max (checking if any element is True)
return self.block_builder.emit(relax.op.max(x, dim, keepdims=keepdim))
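
For reference, a standalone sketch (illustrative only, not part of the patch; assumes the relax.BlockBuilder API on current main) that builds the same bool -> int8 -> bool round trip the importer now emits:

    import tvm
    from tvm import relax

    bb = relax.BlockBuilder()
    x = relax.Var("x", relax.TensorStructInfo((2, 3), "bool"))
    with bb.function("main", [x]):
        with bb.dataflow():
            lv = bb.emit(relax.op.astype(x, "int8"))                 # bool -> int8
            lv1 = bb.emit(relax.op.max(lv, axis=1, keepdims=False))  # reduce in int8
            gv = bb.emit_output(relax.op.astype(lv1, "bool"))        # cast back to bool
        bb.emit_func_output(gv)
    bb.get().show()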
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 662df5e76a..7397b3f21a 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -1693,8 +1693,10 @@ def test_isin():
with R.dataflow():
lv: R.Tensor((10, 10, 1), dtype="float32") = R.reshape(x, R.shape([10, 10, 1]))
lv1: R.Tensor((10, 10, 8), dtype="bool") = R.equal(lv, test_elements)
- lv2: R.Tensor((10, 10), dtype="bool") = R.max(lv1, axis=[-1], keepdims=False)
- gv: R.Tuple(R.Tensor((10, 10), dtype="bool")) = (lv2,)
+ lv2: R.Tensor((10, 10, 8), dtype="int8") = R.astype(lv1, dtype="int8")
+ lv3: R.Tensor((10, 10), dtype="int8") = R.max(lv2, axis=[-1], keepdims=False)
+ lv4: R.Tensor((10, 10), dtype="bool") = R.astype(lv3, dtype="bool")
+ gv: R.Tuple(R.Tensor((10, 10), dtype="bool")) = (lv4,)
R.output(gv)
return gv
@@ -4118,71 +4120,22 @@ def test_scaled_dot_product_attention():
v: R.Tensor((32, 8, 128, 64), dtype="float32"),
) -> R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")):
with R.dataflow():
- lv: R.Tensor((32, 8, 128, 64), dtype="float32") = R.multiply(
- q, R.const(0.35355338454246521, "float32")
+ lv: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims(
+ q, axes=[0, 2, 1, 3]
)
- lv1: R.Tensor((32, 8, 64, 128), dtype="float32") = R.permute_dims(
- k, axes=[0, 1, 3, 2]
+ lv1: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims(
+ k, axes=[0, 2, 1, 3]
)
- lv2: R.Tensor((32, 8, 64, 128), dtype="float32") = R.multiply(
- lv1, R.const(0.35355338454246521, "float32")
+ lv2: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims(
+ v, axes=[0, 2, 1, 3]
)
- lv3: R.Tensor((32, 8, 128, 64), dtype="float32") = R.broadcast_to(
- lv, R.shape([32, 8, 128, 64])
+ lv3: R.Tensor((32, 128, 8, 64), dtype="float32") = R.nn.attention(
+ lv, lv1, lv2, scale=None, causal_mask=None, window_size=None
)
- lv4: R.Tensor((256, 128, 64), dtype="float32") = R.reshape(
- lv3, R.shape([256, 128, 64])
+ lv4: R.Tensor((32, 8, 128, 64), dtype="float32") = R.permute_dims(
+ lv3, axes=[0, 2, 1, 3]
)
- lv5: R.Tensor((32, 8, 64, 128), dtype="float32") = R.broadcast_to(
- lv2, R.shape([32, 8, 64, 128])
- )
- lv6: R.Tensor((256, 64, 128), dtype="float32") = R.reshape(
- lv5, R.shape([256, 64, 128])
- )
- lv7: R.Tensor((256, 128, 128), dtype="float32") = R.matmul(
- lv4, lv6, out_dtype="float32"
- )
- lv8: R.Tensor((32, 8, 128, 128), dtype="float32") = R.reshape(
- lv7, R.shape([32, 8, 128, 128])
- )
- lv9: R.Tensor((32, 8, 128, 128), dtype="float32") = R.nn.softmax(lv8, axis=-1)
- lv10: R.Tensor((32, 8, 128, 128), dtype="bool") = R.equal(
- lv8, R.const(float("-inf"), "float32")
- )
- lv11: R.Tensor((32, 8, 128, 128), dtype="bool") = R.logical_not(lv10)
- lv12: R.Tensor((32, 8, 128, 1), dtype="bool") = R.max(
- lv11, axis=[-1], keepdims=True
- )
- lv13: R.Tensor((32, 8, 128, 1), dtype="bool") = R.logical_not(lv12)
- lv14: R.Tensor((32, 8, 128, 128), dtype="float32") = R.full_like(
- lv9, R.const(0, "int32"), dtype="void"
- )
- lv15: R.Tensor((32, 8, 128, 128), dtype="float32") = R.where(lv13, lv14, lv9)
- lv16: R.Tensor((32, 8, 128, 128), dtype="float32") = R.broadcast_to(
- lv15, R.shape([32, 8, 128, 128])
- )
- lv17: R.Tensor((256, 128, 128), dtype="float32") = R.reshape(
- lv16, R.shape([256, 128, 128])
- )
- lv18: R.Tensor((32, 8, 128, 64), dtype="float32") = R.broadcast_to(
- v, R.shape([32, 8, 128, 64])
- )
- lv19: R.Tensor((256, 128, 64), dtype="float32") = R.reshape(
- lv18, R.shape([256, 128, 64])
- )
- lv20: R.Tensor((256, 128, 64), dtype="float32") = R.matmul(
- lv17, lv19, out_dtype="float32"
- )
- lv21: R.Tensor((32, 8, 128, 64), dtype="float32") = R.reshape(
- lv20, R.shape([32, 8, 128, 64])
- )
- lv22: R.Tensor((128, 32, 8, 64), dtype="float32") = R.permute_dims(
- lv21, axes=[2, 0, 1, 3]
- )
- lv23: R.Tensor((32, 8, 128, 64), dtype="float32") = R.permute_dims(
- lv22, axes=[1, 2, 0, 3]
- )
- gv: R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")) = (lv23,)
+ gv: R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")) = (lv4,)
R.output(gv)
return gv
@@ -4200,72 +4153,22 @@ def test_scaled_dot_product_attention():
mask: R.Tensor((32, 8, 128, 128), dtype="float32"),
) -> R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")):
with R.dataflow():
- lv: R.Tensor((32, 8, 128, 64), dtype="float32") = R.multiply(
- q, R.const(0.35355338454246521, "float32")
- )
- lv1: R.Tensor((32, 8, 64, 128), dtype="float32") = R.permute_dims(
- k, axes=[0, 1, 3, 2]
- )
- lv2: R.Tensor((32, 8, 64, 128), dtype="float32") = R.multiply(
- lv1, R.const(0.35355338454246521, "float32")
- )
- lv3: R.Tensor((32, 8, 128, 64), dtype="float32") = R.broadcast_to(
- lv, R.shape([32, 8, 128, 64])
- )
- lv4: R.Tensor((256, 128, 64), dtype="float32") = R.reshape(
- lv3, R.shape([256, 128, 64])
- )
- lv5: R.Tensor((32, 8, 64, 128), dtype="float32") = R.broadcast_to(
- lv2, R.shape([32, 8, 64, 128])
- )
- lv6: R.Tensor((256, 64, 128), dtype="float32") = R.reshape(
- lv5, R.shape([256, 64, 128])
- )
- lv7: R.Tensor((256, 128, 128), dtype="float32") = R.matmul(
- lv4, lv6, out_dtype="float32"
- )
- lv8: R.Tensor((32, 8, 128, 128), dtype="float32") = R.reshape(
- lv7, R.shape([32, 8, 128, 128])
+ lv: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims(
+ q, axes=[0, 2, 1, 3]
)
- lv9: R.Tensor((32, 8, 128, 128), dtype="float32") = R.add(lv8, mask)
- lv10: R.Tensor((32, 8, 128, 128), dtype="float32") = R.nn.softmax(lv9, axis=-1)
- lv11: R.Tensor((32, 8, 128, 128), dtype="bool") = R.equal(
- lv9, R.const(float("-inf"), "float32")
+ lv1: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims(
+ k, axes=[0, 2, 1, 3]
)
- lv12: R.Tensor((32, 8, 128, 128), dtype="bool") = R.logical_not(lv11)
- lv13: R.Tensor((32, 8, 128, 1), dtype="bool") = R.max(
- lv12, axis=[-1], keepdims=True
+ lv2: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims(
+ v, axes=[0, 2, 1, 3]
)
- lv14: R.Tensor((32, 8, 128, 1), dtype="bool") = R.logical_not(lv13)
- lv15: R.Tensor((32, 8, 128, 128), dtype="float32") = R.full_like(
- lv10, R.const(0, "int32"), dtype="void"
+ lv3: R.Tensor((32, 128, 8, 64), dtype="float32") = R.nn.attention_bias(
+ lv, lv1, lv2, mask, scale=None, causal_mask=None, window_size=None
)
- lv16: R.Tensor((32, 8, 128, 128), dtype="float32") = R.where(lv14, lv15, lv10)
- lv17: R.Tensor((32, 8, 128, 128), dtype="float32") = R.broadcast_to(
- lv16, R.shape([32, 8, 128, 128])
+ lv4: R.Tensor((32, 8, 128, 64), dtype="float32") = R.permute_dims(
+ lv3, axes=[0, 2, 1, 3]
)
- lv18: R.Tensor((256, 128, 128), dtype="float32") = R.reshape(
- lv17, R.shape([256, 128, 128])
- )
- lv19: R.Tensor((32, 8, 128, 64), dtype="float32") = R.broadcast_to(
- v, R.shape([32, 8, 128, 64])
- )
- lv20: R.Tensor((256, 128, 64), dtype="float32") = R.reshape(
- lv19, R.shape([256, 128, 64])
- )
- lv21: R.Tensor((256, 128, 64), dtype="float32") = R.matmul(
- lv18, lv20, out_dtype="float32"
- )
- lv22: R.Tensor((32, 8, 128, 64), dtype="float32") = R.reshape(
- lv21, R.shape([32, 8, 128, 64])
- )
- lv23: R.Tensor((128, 32, 8, 64), dtype="float32") = R.permute_dims(
- lv22, axes=[2, 0, 1, 3]
- )
- lv24: R.Tensor((32, 8, 128, 64), dtype="float32") = R.permute_dims(
- lv23, axes=[1, 2, 0, 3]
- )
- gv: R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")) = (lv24,)
+ gv: R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")) = (lv4,)
R.output(gv)
return gv
@@ -4278,7 +4181,7 @@ def test_scaled_dot_product_attention():
),
{},
Expected1,
- run_ep_decomposition=True,
+ run_ep_decomposition=False,
)
verify_model(
@@ -4291,7 +4194,7 @@ def test_scaled_dot_product_attention():
),
{},
Expected2,
- run_ep_decomposition=True,
+ run_ep_decomposition=False,
)
# Test 2D input (seq_len, head_dim) - bug fix for #18441
@@ -7307,6 +7210,29 @@ def test_take():
verify_model(Take(), example_args, {}, Expected)
+def test_any():
+ class AnyAten(torch.nn.Module):
+ def forward(self, x):
+ return torch.ops.aten.any(x, dim=1)
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3), dtype="bool"),
+ ) -> R.Tuple(R.Tensor((2,), dtype="bool")):
+ with R.dataflow():
+ lv: R.Tensor((2, 3), dtype="int8") = relax.op.astype(x, dtype="int8")
+ lv2: R.Tensor((2,), dtype="int8") = relax.op.max(lv, axis=1, keepdims=False)
+ lv3: R.Tensor((2,), dtype="bool") = relax.op.astype(lv2, dtype="bool")
+ gv: R.Tuple(R.Tensor((2,), dtype="bool")) = (lv3,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.tensor([[0, 0, 0], [0, 1, 0]], dtype=torch.bool),)
+ verify_model(AnyAten(), example_args, {}, Expected)
+
+
def test_std():
class Std(Module):
def forward(self, x):