This is an automated email from the ASF dual-hosted git repository.

spectrometerHBH pushed a commit to branch tir-bench
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit bad4d0a932dd8557d7b8a632bf1fac3e4539d503
Author: Bohan Hou <[email protected]>
AuthorDate: Sun May 24 09:52:18 2026 -0700

    feat(tirx): add typed pointer byte-offset intrinsic (#641)
---
 include/tvm/tirx/builtin.h           |  9 +++++++++
 python/tvm/tirx/__init__.py          |  3 ++-
 python/tvm/tirx/op.py                | 11 +++++++++++
 python/tvm/tirx/script/builder/ir.py |  2 ++
 src/target/source/codegen_c.cc       | 33 ++++++++++++++++++++++++++++++++-
 src/target/source/codegen_c.h        | 10 ++++++++++
 src/tirx/op/builtin.cc               |  4 ++++
 src/tirx/op/op.cc                    |  9 +++++++++
 8 files changed, 79 insertions(+), 2 deletions(-)

diff --git a/include/tvm/tirx/builtin.h b/include/tvm/tirx/builtin.h
index 8627d55574..ff61386699 100644
--- a/include/tvm/tirx/builtin.h
+++ b/include/tvm/tirx/builtin.h
@@ -273,6 +273,15 @@ TVM_DLL const Op& prefetch();
  */
 TVM_DLL const Op& tvm_access_ptr();
 
+/*!
+ * \brief Cast a handle to a typed pointer after adding a byte offset.
+ *
+ *  DType* ptr_byte_offset(void* data, int byte_offset, Expr dtype) {
+ *    return reinterpret_cast<DType*>(reinterpret_cast<char*>(data) + byte_offset);
+ *  }
+ */
+TVM_DLL const Op& ptr_byte_offset();
+
 /*!
  * \brief Create a function local static handle that iniitalizes to nullptr.
  *  can be used to cache function local static resources.
diff --git a/python/tvm/tirx/__init__.py b/python/tvm/tirx/__init__.py
index 10de65a564..efda655066 100644
--- a/python/tvm/tirx/__init__.py
+++ b/python/tvm/tirx/__init__.py
@@ -55,7 +55,8 @@ from .op import tvm_stack_alloca, tvm_stack_make_shape, tvm_stack_make_array
 from .op import tvm_tuple, handle_add_byte_offset, tvm_struct_get, tvm_struct_set
 from .op import address_of, lookup_param, assume, undef
 from .op import continue_loop, break_loop
-from .op import tvm_thread_allreduce, type_annotation, tvm_access_ptr, tvm_throw_last_error
+from .op import tvm_thread_allreduce, type_annotation, tvm_access_ptr, ptr_byte_offset
+from .op import tvm_throw_last_error
 from .op import (
     tvm_load_matrix_sync,
     tvm_store_matrix_sync,
diff --git a/python/tvm/tirx/op.py b/python/tvm/tirx/op.py
index 39276b1e4c..ddf64d2e9c 100644
--- a/python/tvm/tirx/op.py
+++ b/python/tvm/tirx/op.py
@@ -876,6 +876,17 @@ def tvm_access_ptr(ptype, data, offset, extent, rw_mask):
     return call_intrin("handle", "tirx.tvm_access_ptr", ptype, data, offset, extent, rw_mask)
 
 
+def ptr_byte_offset(data, byte_offset, dtype):
+    """Return ``(dtype*)((char*)data + byte_offset)`` as a typed pointer handle.
+
+    ``byte_offset`` is in bytes regardless of ``dtype``; ``dtype`` is a dtype string or a
+    ``type_annotation`` call.  Use it to derive a typed pointer from a byte-addressed base.
+    """
+    if isinstance(dtype, str):
+        dtype = type_annotation(dtype)
+    return call_intrin("handle", "tirx.ptr_byte_offset", data, byte_offset, dtype)
+
+
 def tvm_throw_last_error():
     """Throw TVMGetLastError()
 
diff --git a/python/tvm/tirx/script/builder/ir.py b/python/tvm/tirx/script/builder/ir.py
index 8adc802cb6..da24e71a7d 100644
--- a/python/tvm/tirx/script/builder/ir.py
+++ b/python/tvm/tirx/script/builder/ir.py
@@ -3558,6 +3558,7 @@ trunc = _op_wrapper(_tir_op.trunc)
 truncdiv = _op_wrapper(_tir_op.truncdiv)
 truncmod = _op_wrapper(_tir_op.truncmod)
 tvm_access_ptr = _op_wrapper(_tir_op.tvm_access_ptr)
+ptr_byte_offset = _op_wrapper(_tir_op.ptr_byte_offset)
 tvm_throw_last_error = _op_wrapper(_tir_op.tvm_throw_last_error)
 tvm_stack_alloca = _op_wrapper(_tir_op.tvm_stack_alloca)
 tvm_stack_make_shape = _op_wrapper(_tir_op.tvm_stack_make_shape)
@@ -3882,6 +3883,7 @@ __all__ = [
     "truncdiv",
     "truncmod",
     "tvm_access_ptr",
+    "ptr_byte_offset",
     "tvm_throw_last_error",
     "tvm_stack_alloca",
     "tvm_stack_make_shape",
diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc
index fb3d5f5f38..421cde706f 100644
--- a/src/target/source/codegen_c.cc
+++ b/src/target/source/codegen_c.cc
@@ -30,6 +30,7 @@
 #include <iomanip>
 
 #include "../../arith/pattern_match.h"
+#include "../../tirx/ir/buffer_common.h"
 #include "codegen_params.h"
 
 namespace tvm {
@@ -42,6 +43,7 @@ void CodeGenC::Init(bool output_ssa) { print_ssa_form_ = output_ssa; }
 void CodeGenC::InitFuncState(const PrimFunc& f) {
   alloc_storage_scope_.clear();
   handle_data_type_.clear();
+  pointer_offset_vars_.clear();
   CodeGenSourceBase::ClearFuncState();
   ReserveKeywordsAsUnique();
 }
@@ -395,6 +397,16 @@ void CodeGenC::RegisterHandleType(const VarNode* buf_var, DataType t) {
   }
 }
 
+void CodeGenC::RegisterHandleTypeFromPointer(const tirx::Var& var, const PrimExpr* value) {
+  if (value == nullptr) return;
+  auto* call = value->as<tirx::CallNode>();
+  if (call == nullptr || !call->op.same_as(builtin::ptr_byte_offset())) return;
+  std::optional<DataType> value_dtype = tirx::GetPointerType(GetType(*value));
+  if (!value_dtype.has_value()) return;
+  RegisterHandleType(var.get(), value_dtype.value());
+  pointer_offset_vars_.insert(var.get());
+}
+
 void CodeGenC::PrintVecElemLoad(const std::string& vec, DataType t, int i,
                                 std::ostream& os) {  // NOLINT(*)
   os << vec << ".s" << std::hex << i << std::dec;
@@ -708,7 +720,15 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) {  // NOLINT(*)
       if (load) {
         TVM_FFI_ICHECK_EQ(load->indices.size(), 1)
             << "CodeGenC only supports flat memory allocations.";
-        os << "(&(" << GetBufferRef(load->dtype, load->buffer.get(), load->indices[0]) << "))";
+        const VarNode* data = load->buffer->data.get();
+        if (pointer_offset_vars_.count(data) && load->dtype == load->buffer->dtype &&
+            HandleTypeMatch(data, load->buffer->dtype) && !IsVolatile(data)) {
+          os << "(" << GetVarID(data) << " + ";
+          this->PrintExpr(load->indices[0], os);
+          os << ")";
+        } else {  // e.g. vector loads, whose dtype differs from the element dtype
+          os << "(&(" << GetBufferRef(load->dtype, load->buffer.get(), load->indices[0]) << "))";
+        }
       } else {
         auto* var = op->args[0].as<tirx::VarNode>();
         TVM_FFI_ICHECK(var)
@@ -738,6 +758,15 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) {  // NOLINT(*)
       os << "(";
       this->PrintExpr(op->args[0], os);
       os << " == NULL)";
+    } else if (op->op.same_as(builtin::ptr_byte_offset())) {
+      TVM_FFI_ICHECK_EQ(op->args.size(), 3U);
+      os << "((";
+      PrintType(op->args[2].dtype(), os);
+      os << "*)(((char*)";
+      this->PrintExpr(op->args[0], os);
+      os << ") + ";
+      this->PrintExpr(op->args[1], os);
+      os << "))";
     } else if (op->op.same_as(builtin::handle_add_byte_offset())) {
       TVM_FFI_ICHECK_EQ(op->args.size(), 2U);
       os << "((void*)((char*)";
@@ -953,6 +982,7 @@ void CodeGenC::VisitExpr_(const LetNode* op, std::ostream& os) {  // NOLINT(*)
   } else {
     let_binding_[op->var] = op;
   }
+  RegisterHandleTypeFromPointer(op->var, &op->value);
   std::string value = PrintExpr(op->value);
   if (print_ssa_form_) {
     TVM_FFI_ICHECK(!var_idmap_.count(op->var.get()));
@@ -1077,6 +1107,7 @@ void CodeGenC::VisitExpr_(const SelectNode* op, std::ostream& os) {  // NOLINT(*
 }
 
 void CodeGenC::VisitStmt_(const BindNode* op) {
+  RegisterHandleTypeFromPointer(op->var, &op->value);
   std::string value = PrintExpr(op->value);
   if (print_ssa_form_) {
     TVM_FFI_ICHECK(!var_idmap_.count(op->var.get()));
diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h
index f1d04bf4aa..b044f3f3a4 100644
--- a/src/target/source/codegen_c.h
+++ b/src/target/source/codegen_c.h
@@ -302,6 +302,14 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
    * \param t The type to be checked.
    */
   void RegisterHandleType(const VarNode* buf_var, DataType t);
+  /*!
+   * \brief Register \p var as a typed pointer when \p value is a ptr_byte_offset call.
+   *
+   * Ordinary handle lets remain void* so generic buffer views keep their code
+   * shape; only explicit ptr_byte_offset values opt into typed pointer arithmetic
+   * and are recorded in pointer_offset_vars_.  \p value may be nullptr (no-op).
+   */
+  void RegisterHandleTypeFromPointer(const tirx::Var& var, const PrimExpr* value);
   // override
   void PrintSSAAssign(const std::string& target, const std::string& src, DataType t) override;
   /*! \brief reserves common C keywords */
@@ -318,6 +326,8 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
   std::unordered_map<const VarNode*, std::string> alloc_storage_scope_;
   /*! \brief the data type of allocated buffers */
   std::unordered_map<const VarNode*, DataType> handle_data_type_;
+  /*! \brief Handle vars whose address_of(buffer[index]) should print as ptr + index. */
+  std::unordered_set<const VarNode*> pointer_offset_vars_;
   /*! \brief Record of ops that have pre-defined global symbol. */
   OpAttrMap<TGlobalSymbol> op_attr_global_symbol_ = Op::GetAttrMap<TGlobalSymbol>("TGlobalSymbol");
   // cache commonly used ops
diff --git a/src/tirx/op/builtin.cc b/src/tirx/op/builtin.cc
index 40cbddef96..2f34efe7ab 100644
--- a/src/tirx/op/builtin.cc
+++ b/src/tirx/op/builtin.cc
@@ -176,6 +176,10 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_access_ptr)
     .set_num_inputs(5)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kSpecialCallArg));
 
+TIR_DEFINE_BUILTIN_FUNC(ptr_byte_offset)
+    .set_num_inputs(3)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
+
 TIR_DEFINE_BUILTIN_FUNC(tvm_static_handle)
     .set_num_inputs(0)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kSpecialCallArg));
diff --git a/src/tirx/op/op.cc b/src/tirx/op/op.cc
index c2772ad69f..a6f00bc09d 100644
--- a/src/tirx/op/op.cc
+++ b/src/tirx/op/op.cc
@@ -86,6 +86,15 @@ Type GetType(const PrimExpr& expr) {
           << "to be a type annotation, but found " << type_annotation->op;
       return PointerType(PrimType(type_annotation->dtype));
     }
+    if (access->op.same_as(builtin::ptr_byte_offset())) {
+      TVM_FFI_ICHECK_EQ(access->args.size(), 3U);
+      static auto builtin_op = Op::Get("tirx.type_annotation");
+      const auto* type_annotation = access->args[2].as<tirx::CallNode>();
+      TVM_FFI_ICHECK(type_annotation != nullptr && type_annotation->op.same_as(builtin_op))
+          << "Expected the third argument of builtin ptr_byte_offset() "
+          << "to be a type annotation, but found " << access->args[2];
+      return PointerType(PrimType(type_annotation->dtype));
+    }
   }
 
   if (auto* address_of = expr.as<tirx::CallNode>()) {

Reply via email to