This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new f09e61bd86 [TIR] Extend address_of to support Buffer objects (#18068)
f09e61bd86 is described below

commit f09e61bd86c44e4693e88fb25c99149dabf61ebd
Author: Siyuan Feng <[email protected]>
AuthorDate: Tue Jun 17 20:24:50 2025 +0800

    [TIR] Extend address_of to support Buffer objects (#18068)
    
    This commit enhances the address_of function to accept both Buffer and
    BufferLoad objects. When a Buffer is passed, it automatically creates
    a BufferLoad with zero indices to get the base address.
---
 python/tvm/tir/op.py                               | 18 +++++++++++++-----
 tests/python/tvmscript/test_tvmscript_roundtrip.py |  9 +++++++++
 2 files changed, 22 insertions(+), 5 deletions(-)

diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 57aa060cd0..97eb3d5b2f 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -26,7 +26,7 @@ from tvm.runtime import const
 
 from . import _ffi_api
 from .buffer import Buffer
-from .expr import Call, CommReducer, IntImm, PrimExprWithOp, Var
+from .expr import BufferLoad, Call, CommReducer, IntImm, PrimExprWithOp, Var
 
 
 def _pack_buffer(buf, span=None):
@@ -553,13 +553,13 @@ def tvm_struct_set(arr, index, field, value):
     return call_intrin("int32", "tir.tvm_struct_set", arr, index, field, value)
 
 
-def address_of(buffer_load, span=None):
+def address_of(obj: Union[Buffer, BufferLoad], span: Optional[Span] = None) -> PrimExpr:
     """Returns the address of an element in the buffer
 
     Parameters
     ----------
-    buffer_load: BufferLoad
-        The buffer load.
+    obj: Union[Buffer, BufferLoad]
+        The buffer or buffer load.
 
     span : Optional[Span]
         The location of this operator in the source code.
@@ -569,7 +569,15 @@ def address_of(buffer_load, span=None):
     call : PrimExpr
         The call expression.
     """
-    return call_intrin("handle", "tir.address_of", buffer_load, span=span)
+    if isinstance(obj, Buffer):
+
+        n_dim = len(obj.shape)
+        buffer_load = BufferLoad(obj, [0] * n_dim)
+        return call_intrin("handle", "tir.address_of", buffer_load, span=span)
+    elif isinstance(obj, BufferLoad):
+        return call_intrin("handle", "tir.address_of", obj, span=span)
+    else:
+        raise ValueError(f"Invalid object type: {type(obj)}")
 
 
 def lookup_param(param_name, span=None):
diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py
index cebfb842ba..af2db34415 100644
--- a/tests/python/tvmscript/test_tvmscript_roundtrip.py
+++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py
@@ -4266,5 +4266,14 @@ def test_return_none_no_trailing_type():
     assert "-> None" not in script
 
 
+def test_address_of_buffer():
+    @T.prim_func
+    def func(a: T.handle):
+        A = T.match_buffer(a, (128, 128), "float32")
+        T.evaluate(T.address_of(A))
+
+    assert "T.address_of(A[0, 0])" in func.script()
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to