This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new f09e61bd86 [TIR] Extend address_of to support Buffer objects (#18068)
f09e61bd86 is described below
commit f09e61bd86c44e4693e88fb25c99149dabf61ebd
Author: Siyuan Feng <[email protected]>
AuthorDate: Tue Jun 17 20:24:50 2025 +0800
[TIR] Extend address_of to support Buffer objects (#18068)
This commit enhances the address_of function to accept both Buffer and
BufferLoad objects. When a Buffer is passed, it automatically creates
a BufferLoad with zero indices to get the base address.
---
python/tvm/tir/op.py | 18 +++++++++++++-----
tests/python/tvmscript/test_tvmscript_roundtrip.py | 9 +++++++++
2 files changed, 22 insertions(+), 5 deletions(-)
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 57aa060cd0..97eb3d5b2f 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -26,7 +26,7 @@ from tvm.runtime import const
from . import _ffi_api
from .buffer import Buffer
-from .expr import Call, CommReducer, IntImm, PrimExprWithOp, Var
+from .expr import BufferLoad, Call, CommReducer, IntImm, PrimExprWithOp, Var
def _pack_buffer(buf, span=None):
@@ -553,13 +553,13 @@ def tvm_struct_set(arr, index, field, value):
return call_intrin("int32", "tir.tvm_struct_set", arr, index, field, value)
-def address_of(buffer_load, span=None):
+def address_of(obj: Union[Buffer, BufferLoad], span: Optional[Span] = None) -> PrimExpr:
"""Returns the address of an element in the buffer
Parameters
----------
- buffer_load: BufferLoad
- The buffer load.
+ obj: Union[Buffer, BufferLoad]
+ The buffer or buffer load.
span : Optional[Span]
The location of this operator in the source code.
@@ -569,7 +569,15 @@ def address_of(buffer_load, span=None):
call : PrimExpr
The call expression.
"""
- return call_intrin("handle", "tir.address_of", buffer_load, span=span)
+ if isinstance(obj, Buffer):
+
+ n_dim = len(obj.shape)
+ buffer_load = BufferLoad(obj, [0] * n_dim)
+ return call_intrin("handle", "tir.address_of", buffer_load, span=span)
+ elif isinstance(obj, BufferLoad):
+ return call_intrin("handle", "tir.address_of", obj, span=span)
+ else:
+ raise ValueError(f"Invalid object type: {type(obj)}")
def lookup_param(param_name, span=None):
diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py
index cebfb842ba..af2db34415 100644
--- a/tests/python/tvmscript/test_tvmscript_roundtrip.py
+++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py
@@ -4266,5 +4266,14 @@ def test_return_none_no_trailing_type():
assert "-> None" not in script
+def test_address_of_buffer():
+ @T.prim_func
+ def func(a: T.handle):
+ A = T.match_buffer(a, (128, 128), "float32")
+ T.evaluate(T.address_of(A))
+
+ assert "T.address_of(A[0, 0])" in func.script()
+
+
if __name__ == "__main__":
tvm.testing.main()