This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 30bf568dd0 [Tests] Check WebGPU volatile allreduce annotation
structurally (#19740)
30bf568dd0 is described below
commit 30bf568dd0d1a61e622ac84dda49486292577c92
Author: Shushi Hong <[email protected]>
AuthorDate: Fri Jun 12 00:38:19 2026 -0400
[Tests] Check WebGPU volatile allreduce annotation structurally (#19740)
This PR updates the WebGPU multi-warp allreduce test to check the
generated `tirx.volatile` allocation annotation structurally instead of
matching the exact TVMScript printer output.
The test is intended to verify that `LowerThreadAllreduce` marks the
generated shared allocation as volatile. It previously checked for the
exact string:
```python
"tirx.volatile": T.bool(True)
```
However, the current printer emits the same annotation as:
```python
annotations={"tirx.volatile": True}
```
The transform behavior is unchanged; only the printer spelling differs.
This patch walks the generated TIRX body and checks for an `AllocBuffer`
whose `tirx.volatile` annotation is the boolean `True`. This matches the
actual semantic requirement of the test without depending on how the
printer formats bool literals.
---
.../test_s_tir_transform_lower_thread_all_reduce.py | 14 +++++++++++++-
1 file changed, 13 insertions(+), 1 deletion(-)
diff --git a/tests/python/s_tir/transform/test_s_tir_transform_lower_thread_all_reduce.py b/tests/python/s_tir/transform/test_s_tir_transform_lower_thread_all_reduce.py
index f39ccb6fde..b719416e62 100644
--- a/tests/python/s_tir/transform/test_s_tir_transform_lower_thread_all_reduce.py
+++ b/tests/python/s_tir/transform/test_s_tir_transform_lower_thread_all_reduce.py
@@ -23,6 +23,18 @@ from tvm.script import ir as I
from tvm.script import tirx as T
def _has_volatile_alloc_buffer(mod):
    """Report whether ``mod["main"]`` contains a volatile ``AllocBuffer``.

    Visits every statement in the body of the module's ``main`` function and
    looks for an ``AllocBuffer`` node carrying a ``"tirx.volatile"``
    annotation whose value is the ``True`` singleton. Inspecting the IR node
    directly keeps the check independent of how the TVMScript printer spells
    boolean literals.

    Returns:
        True if at least one such allocation is found, otherwise False.
    """
    volatile_hits = []

    def _record(node):
        # Only allocation statements can carry the volatile marker.
        if not isinstance(node, tvm.tirx.AllocBuffer):
            return
        attrs = node.annotations
        if "tirx.volatile" not in attrs:
            return
        # Identity check: require the annotation to be exactly ``True``.
        if attrs["tirx.volatile"] is True:
            volatile_hits.append(node)

    tvm.tirx.stmt_functor.post_order_visit(mod["main"].body, _record)
    return len(volatile_hits) > 0
+
+
def test_basic():
transform = tvm.s_tir.transform.LowerThreadAllreduce()
@@ -503,7 +515,7 @@ def test_webgpu_multi_warp_reduce():
After_script = After.script()
assert "tvm_warp_shuffle_down" in After_script
assert "tvm_storage_sync" in After_script
- assert '"tirx.volatile": T.bool(True)' in After_script
+ assert _has_volatile_alloc_buffer(After)
assert "T.uint32(" not in After_script