This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b6db2ec89f [Runtime] CutensorMap support (#18097)
b6db2ec89f is described below

commit b6db2ec89f6cf62562a790b3cc98404ef505c0fa
Author: Bohan Hou <[email protected]>
AuthorDate: Mon Jun 30 07:59:52 2025 -0400

    [Runtime] CutensorMap support (#18097)
    
    This PR introduces cuTensorMap support in the runtime module. It enables
    calling kernels whose arguments are cuTensorMap objects: these arguments are
    passed as handles (addresses) and are associated with arg_extra_tags that
    indicate they are tensor maps. The tensor map is allocated on the stack and
    initialized through a runtime API (runtime.cuTensorMapEncodeTiled).
---
 include/tvm/ir/type.h                              |  29 +++
 include/tvm/script/ir_builder/tir/ir.h             |   2 +
 python/tvm/ir/type.py                              |  16 ++
 python/tvm/script/ir_builder/tir/ir.py             |   2 +
 src/ir/type.cc                                     |  12 ++
 src/runtime/cuda/cuda_device_api.cc                | 206 +++++++++++++++++++++
 src/runtime/cuda/cuda_module.cc                    |   2 +-
 src/runtime/file_utils.cc                          |  13 ++
 src/runtime/meta_data.h                            |   3 +
 src/runtime/pack_args.h                            |  26 ++-
 src/script/ir_builder/tir/ir.cc                    |   1 +
 src/script/printer/tir/expr.cc                     |  30 +--
 src/target/build_common.h                          |  10 +
 src/target/llvm/codegen_cpu.cc                     |   4 +
 src/target/llvm/codegen_llvm.cc                    |   8 +-
 src/target/llvm/codegen_llvm.h                     |   1 +
 src/target/source/codegen_c.cc                     |  16 +-
 src/target/source/codegen_cuda.cc                  |   3 +-
 tests/python/codegen/test_target_codegen_cuda.py   |  31 ++++
 .../test_tir_transform_inject_ptx_async_copy.py    |   4 +-
 20 files changed, 396 insertions(+), 23 deletions(-)

diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h
index 5ca35449fc..d864766d7f 100644
--- a/include/tvm/ir/type.h
+++ b/include/tvm/ir/type.h
@@ -330,5 +330,34 @@ class FuncType : public Type {
   TVM_DEFINE_OBJECT_REF_METHODS(FuncType, Type, FuncTypeNode);
 };
 
+/*!
+ * \brief The type of tensor map.
+ * \sa TensorMapType
+ */
+class TensorMapTypeNode : public TypeNode {
+ public:
+  void VisitAttrs(AttrVisitor* v) { v->Visit("span", &span); }
+
+  bool SEqualReduce(const TensorMapTypeNode* other, SEqualReducer equal) const 
{
+    return equal(span, other->span);
+  }
+
+  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(span); }
+
+  static constexpr const char* _type_key = "TensorMapType";
+  TVM_DECLARE_FINAL_OBJECT_INFO(TensorMapTypeNode, TypeNode);
+};
+
+/*!
+ * \brief Managed reference to TensorMapTypeNode.
+ * \sa TensorMapTypeNode
+ */
+class TensorMapType : public Type {
+ public:
+  TVM_DLL TensorMapType(Span span = Span());
+
+  TVM_DEFINE_OBJECT_REF_METHODS_WITHOUT_DEFAULT_CONSTRUCTOR(TensorMapType, 
Type, TensorMapTypeNode);
+};
+
 }  // namespace tvm
 #endif  // TVM_IR_TYPE_H_
diff --git a/include/tvm/script/ir_builder/tir/ir.h 
b/include/tvm/script/ir_builder/tir/ir.h
index febdac55d9..30b5bb3382 100644
--- a/include/tvm/script/ir_builder/tir/ir.h
+++ b/include/tvm/script/ir_builder/tir/ir.h
@@ -452,6 +452,8 @@ inline Var Handle(runtime::DataType dtype = 
runtime::DataType::Void(),
   return is_size_var ? tvm::tir::SizeVar("", type_annotation) : 
tvm::tir::Var("", type_annotation);
 }
 
+inline Var TensormapHandle() { return tvm::tir::Var("", 
PointerType(TensorMapType())); }
+
 #define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType)                     
                \
   inline PrimExpr FuncName(Optional<PrimExpr> expr = std::nullopt, bool 
is_size_var = false) { \
     DataType dtype = DType;                                                    
                \
diff --git a/python/tvm/ir/type.py b/python/tvm/ir/type.py
index 9ec8ef8fbd..d0bf7014e2 100644
--- a/python/tvm/ir/type.py
+++ b/python/tvm/ir/type.py
@@ -107,3 +107,19 @@ class FuncType(Type):
             arg_types,
             ret_type,
         )
+
+
[email protected]_object("TensorMapType")
+class TensorMapType(Type):
+    """TensorMapType used in the low-level TIR.
+
+    Parameters
+    ----------
+    span : tvm.ir.Span
+        The span information.
+    """
+
+    def __init__(self, span=None):
+        self.__init_handle_by_constructor__(
+            _ffi_api.TensorMapType, span  # pylint: disable=no-member
+        )
diff --git a/python/tvm/script/ir_builder/tir/ir.py 
b/python/tvm/script/ir_builder/tir/ir.py
index c7589f4a19..5864de2cac 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -1566,6 +1566,8 @@ def handle(
     res : PrimExpr
         The new tir.Var with type handle or casted expression with type handle.
     """
+    if dtype == "tensormap":
+        return _ffi_api.TensormapHandle()  # type: ignore[attr-defined] # 
pylint: disable=no-member
     is_unknown_type = dtype is None
     if dtype is None:
         dtype = "void"
diff --git a/src/ir/type.cc b/src/ir/type.cc
index 95b65475be..cd7b6a523c 100644
--- a/src/ir/type.cc
+++ b/src/ir/type.cc
@@ -89,4 +89,16 @@ 
TVM_FFI_REGISTER_GLOBAL("ir.TupleType").set_body_typed([](Array<Type> fields) {
   return TupleType(fields);
 });
 
+TVM_FFI_REGISTER_GLOBAL("ir.TensorMapType").set_body_typed([](Span span) {
+  return TensorMapType(span);
+});
+
+TensorMapType::TensorMapType(Span span) {
+  ObjectPtr<TensorMapTypeNode> n = make_object<TensorMapTypeNode>();
+  n->span = std::move(span);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_NODE_TYPE(TensorMapTypeNode);
+
 }  // namespace tvm
diff --git a/src/runtime/cuda/cuda_device_api.cc 
b/src/runtime/cuda/cuda_device_api.cc
index 399312e193..98a83f4ed7 100644
--- a/src/runtime/cuda/cuda_device_api.cc
+++ b/src/runtime/cuda/cuda_device_api.cc
@@ -357,5 +357,211 @@ TVM_DLL int GetCudaDeviceCount() {
 
 
TVM_FFI_REGISTER_GLOBAL("runtime.GetCudaDeviceCount").set_body_typed(GetCudaDeviceCount);
 
+/**
+ * \brief FFI wrapper for cuTensorMapEncodeTiled.
+ *
+ * This function registers a global function `runtime.cuTensorMapEncodeTiled` 
that can be
+ * called from other parts of the TVM runtime (e.g., Python). It wraps the 
CUDA Driver API
+ * function `cuTensorMapEncodeTiled`, which initializes a tensor map 
descriptor (CUtensorMap).
+ *
+ * \param tensor_map (handle): A `void*` pointer to the CUtensorMap object to 
be initialized.
+ * \param tensor_dtype (DataType): The TVM data type of the tensor.
+ * \param tensor_rank (int): The rank (number of dimensions) of the tensor.
+ * \param tensor_ptr (handle): A `void*` pointer to the start of the tensor in 
global memory.
+ * \param global_shape (int...): `tensor_rank` integer arguments for the 
global tensor dimensions.
+ * \param global_strides (int...): `tensor_rank - 1` integer arguments for the 
global tensor
+ * strides. The stride for the innermost dimension is not provided as it's 
assumed to be contiguous.
+ * \param shared_shape (int...): `tensor_rank` integer arguments for the shape 
of the tile (box)
+ * in shared memory.
+ * \param shared_strides (int...): `tensor_rank` integer arguments for the 
strides of the tile (box)
+ * in shared memory.
+ * \param interleaved_kind (int): An integer corresponding to the 
CUtensorMapInterleave enum.
+ * \param swizzle_kind (int): An integer corresponding to the 
CUtensorMapSwizzle enum.
+ * \param l2_promotion_kind (int): An integer corresponding to the 
CUtensorMapL2promotion enum.
+ * \param oob_fill_kind (int): An integer corresponding to the 
CUtensorMapFloatOOBfill enum.
+ */
+TVM_FFI_REGISTER_GLOBAL("runtime.cuTensorMapEncodeTiled")
+    .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) {
+      CHECK_GE(args.size(), 4) << "init_cuTensorMap expects at least 4 
arguments";
+      size_t arg_cnt = 0;
+      CUtensorMap* tensor_map = 
static_cast<CUtensorMap*>(args[arg_cnt++].cast<void*>());
+      runtime::DataType tensor_dtype = 
args[arg_cnt++].cast<runtime::DataType>();
+      uint32_t tensor_rank = 
static_cast<uint32_t>(args[arg_cnt++].cast<int32_t>());
+      void* tensor_ptr = static_cast<void*>(args[arg_cnt++].cast<void*>());
+
+      CHECK_EQ(args.size(), 4 + tensor_rank * 4 + 3)
+          << "cuTensorMapEncodeTiled expects " << 4 + tensor_rank * 4 + 3 << " 
arguments"
+          << "tensor_map, tensor_dtype, tensor_rank, tensor_ptr, 
global_shape(" << tensor_rank
+          << "), global_strides(" << tensor_rank - 1 << "), shared_shape(" << 
tensor_rank
+          << "), shared_strides(" << tensor_rank << "), interleaved_kind, 
swizzle_kind"
+          << ", l2_promotion_kind, oob_fill_kind";
+
+      std::vector<cuuint64_t> global_shape(tensor_rank);
+      std::vector<cuuint64_t> global_strides(tensor_rank);
+      std::vector<uint32_t> shared_shape(tensor_rank);
+      std::vector<uint32_t> shared_strides(tensor_rank);
+      for (size_t i = 0; i < tensor_rank; ++i) {
+        global_shape[i] = 
static_cast<cuuint64_t>(args[arg_cnt++].cast<int64_t>());
+      }
+      for (size_t i = 0; i < tensor_rank - 1; ++i) {
+        global_strides[i] = 
static_cast<cuuint64_t>(args[arg_cnt++].cast<int64_t>());
+        CHECK_EQ(global_strides[i] % 16, 0) << "global strides must be 
multiple of 16";
+      }
+      for (size_t i = 0; i < tensor_rank; ++i) {
+        shared_shape[i] = 
static_cast<uint32_t>(args[arg_cnt++].cast<int32_t>());
+        CHECK_GE(shared_shape[i], 0) << "boxDim must be non-negative";
+        CHECK_LE(shared_shape[i], 256) << "boxDim must be less than or equal 
to 256";
+      }
+      for (size_t i = 0; i < tensor_rank; ++i) {
+        shared_strides[i] = 
static_cast<uint32_t>(args[arg_cnt++].cast<int32_t>());
+      }
+      auto interleaved_kind = 
static_cast<CUtensorMapInterleave>(args[arg_cnt++].cast<int>());
+      auto swizzle_kind = 
static_cast<CUtensorMapSwizzle>(args[arg_cnt++].cast<int>());
+      auto l2_promotion_kind = 
static_cast<CUtensorMapL2promotion>(args[arg_cnt++].cast<int>());
+      auto oob_fill_kind = 
static_cast<CUtensorMapFloatOOBfill>(args[arg_cnt++].cast<int>());
+
+      ICHECK_EQ(tensor_dtype.lanes(), 1)
+          << "Expect tensor_dtype to have lanes=1, but get " << tensor_dtype;
+      CUtensorMapDataType cu_dtype;
+      switch (tensor_dtype.code()) {
+        case DataType::kInt:
+          // int
+          switch (tensor_dtype.bits()) {
+            case 8:
+              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT8;
+              break;
+            case 32:
+              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_INT32;
+              break;
+            case 64:
+              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_INT64;
+              break;
+            default:
+              LOG(FATAL) << "Unsupported data type " << 
runtime::DLDataTypeToString(tensor_dtype);
+          }
+          break;
+        case DataType::kUInt:
+          // unsigned int
+          switch (tensor_dtype.bits()) {
+            case 8:
+              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT8;
+              break;
+            case 16:
+              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT16;
+              break;
+            case 32:
+              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT32;
+              break;
+            case 64:
+              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT64;
+              break;
+            default:
+              LOG(FATAL) << "Unsupported data type " << 
runtime::DLDataTypeToString(tensor_dtype);
+          }
+          break;
+        case DataType::kFloat:
+          // float
+          switch (tensor_dtype.bits()) {
+            case 16:
+              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
+              break;
+            case 32:
+              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
+              break;
+            case 64:
+              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_FLOAT64;
+              break;
+            default:
+              LOG(FATAL) << "Unsupported data type " << 
runtime::DLDataTypeToString(tensor_dtype);
+          }
+          break;
+        case DataType::kBFloat:
+          // bfloat
+          switch (tensor_dtype.bits()) {
+            case 16:
+              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
+              break;
+            default:
+              LOG(FATAL) << "Unsupported data type " << 
runtime::DLDataTypeToString(tensor_dtype);
+          }
+          break;
+        case DataType::kFloat8_e4m3fn:
+          // NV float8 e4m3
+          cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT8;
+          break;
+        case DataType::kFloat8_e5m2:
+          // NV float8 e5m2
+          cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT8;
+          break;
+        default:
+          LOG(FATAL) << "Unsupported data type " << 
runtime::DLDataTypeToString(tensor_dtype);
+      }
+
+      // sanity checks per cuTensorMapEncodeTiled requirements
+      // see
+      // 
https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7
+      CHECK_EQ((reinterpret_cast<uint64_t>(tensor_ptr) & 0b1111), 0);    // 
16-byte alignment
+      CHECK_EQ((reinterpret_cast<uint64_t>(tensor_map) & 0b111111), 0);  // 
64-byte alignment
+      CHECK_LE(tensor_rank, 5) << "cuTensorMapEncodeTiled only supports up to 
5D tensors";
+
+      if (swizzle_kind == CU_TENSOR_MAP_SWIZZLE_32B) {
+        CHECK_LE(shared_shape[0] * tensor_dtype.bytes(), 32)
+            << "CU_TENSOR_MAP_SWIZZLE_32B implies the bounding box inner 
dimension will be <= 32.";
+      } else if (swizzle_kind == CU_TENSOR_MAP_SWIZZLE_64B) {
+        CHECK_LE(shared_shape[0] * tensor_dtype.bytes(), 64)
+            << "CU_TENSOR_MAP_SWIZZLE_64B implies the bounding box inner 
dimension will be <= 64.";
+      } else if (swizzle_kind == CU_TENSOR_MAP_SWIZZLE_128B) {
+        CHECK_LE(shared_shape[0] * tensor_dtype.bytes(), 128)
+            << "CU_TENSOR_MAP_SWIZZLE_128B implies the bounding box inner 
dimension will be <= "
+               "128.";
+      }
+
+      const cuuint64_t* global_shape_ptr = global_shape.data();
+      const cuuint64_t* global_strides_ptr = global_strides.data();
+      const uint32_t* shared_shape_ptr = shared_shape.data();
+      const uint32_t* shared_strides_ptr = shared_strides.data();
+
+      CUresult res =
+          cuTensorMapEncodeTiled(tensor_map, cu_dtype, tensor_rank, 
tensor_ptr, global_shape_ptr,
+                                 global_strides_ptr, shared_shape_ptr, 
shared_strides_ptr,
+                                 interleaved_kind, swizzle_kind, 
l2_promotion_kind, oob_fill_kind);
+      const char* errstr;
+      cuGetErrorString(res, &errstr);
+      if (res != CUDA_SUCCESS) {
+        // get error string
+        const char* error_string = nullptr;
+        cuGetErrorString(res, &error_string);
+        std::cerr << "Error in cuTensorMapEncodeTiled: " << error_string << 
std::endl;
+        std::cout << "cu_dtype: " << cu_dtype << "\n";
+        std::cout << "TMA Desc Addr:   " << tensor_map << "\n";
+        std::cout << "TMA Interleave:  " << interleaved_kind << "\n";
+        std::cout << "TMA L2Promotion: " << l2_promotion_kind << "\n";
+        std::cout << "TMA OOBFill:     " << oob_fill_kind << "\n";
+        std::cout << "SMEM Swizzle:    " << swizzle_kind << "\n";
+        std::cout << "tensor rank: " << tensor_rank << "\n";
+        std::cout << "global prob shape: ";
+        for (size_t i = 0; i < tensor_rank; i++) {
+          std::cout << global_shape[i] << " ";
+        }
+        std::cout << "\n";
+        std::cout << "global prob stride: ";
+        for (size_t i = 0; i < tensor_rank; i++) {
+          std::cout << global_strides[i] << " ";
+        }
+        std::cout << "\n";
+        std::cout << "smem box shape: ";
+        for (size_t i = 0; i < tensor_rank; i++) {
+          std::cout << shared_shape[i] << " ";
+        }
+        std::cout << "\n";
+        std::cout << "smem box stride: ";
+        for (size_t i = 0; i < tensor_rank; i++) {
+          std::cout << shared_strides[i] << " ";
+        }
+        std::cout << "\n";
+        CHECK_EQ(res, CUDA_SUCCESS) << "Error in cuTensorMapEncodeTiled: " << 
errstr;
+      }
+    });
+
 }  // namespace runtime
 }  // namespace tvm
diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc
index a29d303acf..6d69fde5cd 100644
--- a/src/runtime/cuda/cuda_module.cc
+++ b/src/runtime/cuda/cuda_module.cc
@@ -266,7 +266,7 @@ ffi::Function CUDAModuleNode::GetFunction(const String& 
name,
   const FunctionInfo& info = it->second;
   CUDAWrappedFunc f;
   f.Init(this, sptr_to_self, name, info.arg_types.size(), 
info.launch_param_tags);
-  return PackFuncVoidAddr(f, info.arg_types);
+  return PackFuncVoidAddr(f, info.arg_types, info.arg_extra_tags);
 }
 
 Module CUDAModuleCreate(std::string data, std::string fmt,
diff --git a/src/runtime/file_utils.cc b/src/runtime/file_utils.cc
index 2aa377b9f8..513efbd9fb 100644
--- a/src/runtime/file_utils.cc
+++ b/src/runtime/file_utils.cc
@@ -45,6 +45,11 @@ void FunctionInfo::Save(dmlc::JSONWriter* writer) const {
   writer->WriteObjectKeyValue("name", name);
   writer->WriteObjectKeyValue("arg_types", sarg_types);
   writer->WriteObjectKeyValue("launch_param_tags", launch_param_tags);
+  std::vector<int> iarg_extra_tags(arg_extra_tags.size());
+  for (size_t i = 0; i < arg_extra_tags.size(); ++i) {
+    iarg_extra_tags[i] = static_cast<int>(arg_extra_tags[i]);
+  }
+  writer->WriteObjectKeyValue("arg_extra_tags", iarg_extra_tags);
   writer->EndObject();
 }
 
@@ -56,6 +61,12 @@ void FunctionInfo::Load(dmlc::JSONReader* reader) {
   helper.DeclareOptionalField("launch_param_tags", &launch_param_tags);
   helper.DeclareOptionalField("thread_axis_tags",
                               &launch_param_tags);  // for backward 
compatibility
+  std::vector<int> iarg_extra_tags;
+  helper.DeclareOptionalField("arg_extra_tags", &iarg_extra_tags);
+  arg_extra_tags.resize(iarg_extra_tags.size());
+  for (size_t i = 0; i < arg_extra_tags.size(); ++i) {
+    arg_extra_tags[i] = static_cast<ArgExtraTags>(iarg_extra_tags[i]);
+  }
   helper.ReadAllFields(reader);
   arg_types.resize(sarg_types.size());
   for (size_t i = 0; i < arg_types.size(); ++i) {
@@ -67,12 +78,14 @@ void FunctionInfo::Save(dmlc::Stream* writer) const {
   writer->Write(name);
   writer->Write(arg_types);
   writer->Write(launch_param_tags);
+  writer->Write(arg_extra_tags);
 }
 
 bool FunctionInfo::Load(dmlc::Stream* reader) {
   if (!reader->Read(&name)) return false;
   if (!reader->Read(&arg_types)) return false;
   if (!reader->Read(&launch_param_tags)) return false;
+  if (!reader->Read(&arg_extra_tags)) return false;
   return true;
 }
 
diff --git a/src/runtime/meta_data.h b/src/runtime/meta_data.h
index 51120c1f9e..8acefecaad 100644
--- a/src/runtime/meta_data.h
+++ b/src/runtime/meta_data.h
@@ -59,6 +59,9 @@ struct FunctionInfo {
   std::vector<DLDataType> arg_types;
   std::vector<std::string> launch_param_tags;
 
+  enum class ArgExtraTags : int { kNone = 0, kTensorMap = 1 };
+  std::vector<ArgExtraTags> arg_extra_tags;
+
   void Save(dmlc::JSONWriter* writer) const;
   void Load(dmlc::JSONReader* reader);
   void Save(dmlc::Stream* writer) const;
diff --git a/src/runtime/pack_args.h b/src/runtime/pack_args.h
index 0068db51d5..8929f90b0f 100644
--- a/src/runtime/pack_args.h
+++ b/src/runtime/pack_args.h
@@ -64,12 +64,15 @@ union ArgUnion64 {
  *
  * \param f with signiture (ffi::PackedArgs args, ffi::Any* rv, void* 
void_args)
  * \param arg_types The arguments type information.
+ * \param arg_extra_tags extra tags for the arguments
  * \tparam F the function type
  *
  * \return The wrapped packed function.
  */
 template <typename F>
-inline ffi::Function PackFuncVoidAddr(F f, const std::vector<DLDataType>& 
arg_types);
+inline ffi::Function PackFuncVoidAddr(
+    F f, const std::vector<DLDataType>& arg_types,
+    const std::vector<FunctionInfo::ArgExtraTags>& arg_extra_tags = {});
 /*!
  * \brief Create a packed function that from function only packs buffer 
arguments.
  *
@@ -130,7 +133,8 @@ enum ArgConvertCode {
   INT64_TO_UINT32,
   FLOAT64_TO_FLOAT32,
   FLOAT64_TO_FLOAT64,
-  HANDLE_TO_HANDLE
+  HANDLE_TO_HANDLE,
+  HANDLE_TO_TENSORMAP
 };
 
 inline ArgConvertCode GetArgConvertCode(DLDataType t) {
@@ -183,6 +187,10 @@ inline ffi::Function PackFuncVoidAddr_(F f, const 
std::vector<ArgConvertCode>& c
           addr[i] = &(holder[i]);
           break;
         }
+        case HANDLE_TO_TENSORMAP: {
+          addr[i] = raw_args[i].v_ptr;
+          break;
+        }
       }
     }
     f(args, ret, addr);
@@ -222,7 +230,8 @@ inline ffi::Function PackFuncNonBufferArg_(F f, int base,
           holder[i].v_float32[0] = static_cast<float>(raw_args[base + 
i].v_float64);
           break;
         }
-        case HANDLE_TO_HANDLE: {
+        case HANDLE_TO_HANDLE:
+        case HANDLE_TO_TENSORMAP: {
           LOG(FATAL) << "not reached";
           break;
         }
@@ -284,6 +293,7 @@ inline ffi::Function PackFuncPackedArgAligned_(F f, const 
std::vector<ArgConvert
           ++ptr;
           break;
         }
+        case HANDLE_TO_TENSORMAP:
         default: {
           LOG(FATAL) << "not reached";
           break;
@@ -297,10 +307,16 @@ inline ffi::Function PackFuncPackedArgAligned_(F f, const 
std::vector<ArgConvert
 }  // namespace detail
 
 template <typename F>
-inline ffi::Function PackFuncVoidAddr(F f, const std::vector<DLDataType>& 
arg_types) {
+inline ffi::Function PackFuncVoidAddr(
+    F f, const std::vector<DLDataType>& arg_types,
+    const std::vector<FunctionInfo::ArgExtraTags>& arg_extra_tags) {
   std::vector<detail::ArgConvertCode> codes(arg_types.size());
   for (size_t i = 0; i < arg_types.size(); ++i) {
-    codes[i] = detail::GetArgConvertCode(arg_types[i]);
+    if (arg_extra_tags.size() > i && arg_extra_tags[i] == 
FunctionInfo::ArgExtraTags::kTensorMap) {
+      codes[i] = detail::HANDLE_TO_TENSORMAP;
+    } else {
+      codes[i] = detail::GetArgConvertCode(arg_types[i]);
+    }
   }
   size_t num_void_args = arg_types.size();
   // specialization
diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc
index 79e14feee4..7ef970fa09 100644
--- a/src/script/ir_builder/tir/ir.cc
+++ b/src/script/ir_builder/tir/ir.cc
@@ -830,6 +830,7 @@ 
TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float4E2M1FN", Float4E2M1FN
 
 
TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Boolean").set_body_typed(Boolean);
 TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Handle").set_body_typed(Handle);
+TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.TensormapHandle").set_body_typed(TensormapHandle);
 TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Void").set_body_typed(Void);
 
 TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.min")
diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc
index 549247449e..d2f02f7908 100644
--- a/src/script/printer/tir/expr.cc
+++ b/src/script/printer/tir/expr.cc
@@ -37,19 +37,23 @@ ExprDoc PrintVarCreation(const tir::Var& var, const 
ObjectPath& var_p, const IRD
   }
 
   if (const auto* ptr_type = type.as<PointerTypeNode>()) {
-    const auto* prim_type = ptr_type->element_type.as<PrimTypeNode>();
-    ICHECK(prim_type);
-    ExprDoc element_type =
-        LiteralDoc::DataType(prim_type->dtype, 
type_p->Attr("element_type")->Attr("dtype"));
-    rhs = TIR(d, "handle");
-    rhs->source_paths.push_back(var_p->Attr("dtype"));
-    if (ptr_type->storage_scope == "") {
-      rhs = rhs->Call({element_type}, kwargs_keys, kwargs_values);
-    } else {
-      rhs = rhs->Call({element_type,
-                       LiteralDoc::Str(ptr_type->storage_scope,  //
-                                       type_p->Attr("storage_scope"))},
-                      kwargs_keys, kwargs_values);
+    if (const auto* prim_type = ptr_type->element_type.as<PrimTypeNode>()) {
+      ExprDoc element_type =
+          LiteralDoc::DataType(prim_type->dtype, 
type_p->Attr("element_type")->Attr("dtype"));
+      rhs = TIR(d, "handle");
+      rhs->source_paths.push_back(var_p->Attr("dtype"));
+      if (ptr_type->storage_scope == "") {
+        rhs = rhs->Call({element_type}, kwargs_keys, kwargs_values);
+      } else {
+        rhs = rhs->Call({element_type,
+                         LiteralDoc::Str(ptr_type->storage_scope,  //
+                                         type_p->Attr("storage_scope"))},
+                        kwargs_keys, kwargs_values);
+      }
+    } else if (ptr_type->element_type->IsInstance<TensorMapTypeNode>()) {
+      rhs = TIR(d, "handle")
+                ->Call({LiteralDoc::Str("tensormap", 
type_p->Attr("element_type")->Attr("dtype"))},
+                       {}, {});
     }
   } else {
     rhs = TIR(d, DType2Str(var->dtype));
diff --git a/src/target/build_common.h b/src/target/build_common.h
index 70f15d091e..fda7e2e67c 100644
--- a/src/target/build_common.h
+++ b/src/target/build_common.h
@@ -49,6 +49,16 @@ inline std::unordered_map<std::string, 
runtime::FunctionInfo> ExtractFuncInfo(co
     runtime::FunctionInfo info;
     for (size_t i = 0; i < f->params.size(); ++i) {
       info.arg_types.push_back(f->params[i].dtype());
+      auto is_tensormap = [](const tir::Var& var) -> bool {
+        const auto* type = var->type_annotation.as<PointerTypeNode>();
+        if (type == nullptr) {
+          return false;
+        }
+        return type->element_type.as<TensorMapTypeNode>() != nullptr;
+      };
+      info.arg_extra_tags.push_back(is_tensormap(f->params[i])
+                                        ? 
runtime::FunctionInfo::ArgExtraTags::kTensorMap
+                                        : 
runtime::FunctionInfo::ArgExtraTags::kNone);
     }
     if (auto opt = f->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams)) {
       for (const auto& tag : opt.value()) {
diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc
index 76e825ab75..b16617e3d6 100644
--- a/src/target/llvm/codegen_cpu.cc
+++ b/src/target/llvm/codegen_cpu.cc
@@ -1037,6 +1037,10 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* 
op) {
         return builder_->CreateAlloca(t_tvm_ffi_any_, num);
       } else if (type == "array") {
         return builder_->CreateAlloca(t_tvm_array_, num);
+      } else if (type == "tensormap") {
+        auto* alloca = builder_->CreateAlloca(t_tvm_tensormap_, num);
+        alloca->setAlignment(llvm::Align(64));
+        return alloca;
       } else {
         LOG(FATAL) << "Unknown stack alloca type " << type;
       }
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index e9bcfa97fd..45dafa85b9 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -154,6 +154,8 @@ void CodeGenLLVM::Init(const std::string& module_name, 
LLVMTarget* llvm_target,
   t_int32_ = llvm::Type::getInt32Ty(*ctx);
   t_int64_ = llvm::Type::getInt64Ty(*ctx);
   t_float64_ = llvm::Type::getDoubleTy(*ctx);
+  // CUTensorMap is a 128 byte struct, so we use a 128 byte array to represent 
it.
+  t_tvm_tensormap_ = llvm::ArrayType::get(t_char_, 128);
   // meta data
   md_very_likely_branch_ = md_builder_->createBranchWeights(1 << 20, 1);
   md_tbaa_root_ = md_builder_->createTBAARoot("tvm-tbaa");
@@ -620,11 +622,15 @@ llvm::Type* CodeGenLLVM::GetLLVMType(const Type& type) 
const {
       if (primtype->dtype.is_void() || primtype->dtype.code() >= 
DataType::kCustomBegin) {
         return t_void_p_;
       }
+    } else if (ptr->element_type->IsInstance<TensorMapTypeNode>()) {
+      return t_tvm_tensormap_->getPointerTo();
     }
     // TODO(tvm-team) consider put storage scope into the pointer type.
     return llvmGetPointerTo(GetLLVMType(ptr->element_type), 
GetGlobalAddressSpace());
   } else if (IsVoidType(type)) {
     return t_void_;
+  } else if (type->IsInstance<TensorMapTypeNode>()) {
+    return t_tvm_tensormap_;
   } else {
     LOG(FATAL) << "Type " << type << " does not have a corresponding LLVM 
Type";
   }
@@ -2292,7 +2298,7 @@ llvm::DIType* CodeGenLLVM::GetDebugType(const Type& 
ty_tir) {
   return GetDebugType(ty_tir, GetLLVMType(ty_tir));
 }
 llvm::DIType* CodeGenLLVM::GetDebugType(const Type& ty_tir, llvm::Type* 
ty_llvm) {
-  if (ty_llvm == nullptr || ty_llvm == t_void_) {
+  if (ty_llvm == nullptr || ty_llvm == t_void_ || ty_llvm == t_tvm_tensormap_) 
{
     return nullptr;
 
   } else if (ty_llvm->isPointerTy()) {
diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h
index f7e4e81903..e1667b6375 100644
--- a/src/target/llvm/codegen_llvm.h
+++ b/src/target/llvm/codegen_llvm.h
@@ -540,6 +540,7 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const 
PrimExpr&)>,
   llvm::Type* t_int32_{nullptr};
   llvm::Type* t_int64_{nullptr};
   llvm::Type* t_float64_{nullptr};
+  llvm::ArrayType* t_tvm_tensormap_{nullptr};
   // meta data
   llvm::MDNode* md_very_likely_branch_{nullptr};
   llvm::MDNode* md_tbaa_root_{nullptr};
diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc
index 344c0857c4..11f0eaf1ba 100644
--- a/src/target/source/codegen_c.cc
+++ b/src/target/source/codegen_c.cc
@@ -93,10 +93,24 @@ void CodeGenC::PrintFunctionSignature(const String& 
function_name, const PrimFun
       PrintStorageScope(it->second, os);
     }
 
-    PrintType(GetType(v), os);
+    auto is_tensormap_ptr = [&]() -> bool {
+      if (auto* ptr = v->type_annotation.as<PointerTypeNode>()) {
+        return ptr->element_type.as<TensorMapTypeNode>();
+      }
+      return false;
+    };
+    if (is_tensormap_ptr()) {
+      os << "const __grid_constant__ CUtensorMap";
+    } else {
+      PrintType(GetType(v), os);
+    }
 
     bool no_alias = func->HasNonzeroAttr(tir::attr::kNoAlias);
     bool is_handle = v.dtype().is_handle();
+    auto* ptr = v->type_annotation.as<PointerTypeNode>();
+    if (ptr && ptr->element_type.as<TensorMapTypeNode>()) {
+      is_handle = false;
+    }
     if (no_alias && is_handle) {
       PrintRestrict(v, os);
     }
diff --git a/src/target/source/codegen_cuda.cc 
b/src/target/source/codegen_cuda.cc
index d4e1b785b8..21fbc20f47 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -182,6 +182,8 @@ void CodeGenCUDA::PrintExtraAttrs(const PrimFunc& f, 
std::ostream& os) {
 }
 
 std::string CodeGenCUDA::Finish() {
+  decl_stream << "#include <cuda.h>\n";
+
   if (enable_fp16_) {
     decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)\n";
     decl_stream << "#include <cuda_fp16.h>\n";
@@ -194,7 +196,6 @@ std::string CodeGenCUDA::Finish() {
     decl_stream << _cuda_half_t_def;
     decl_stream << "#endif\n\n";
 
-    decl_stream << "#include <cuda.h>\n";
     decl_stream << _cuda_half_util;
   }
 
diff --git a/tests/python/codegen/test_target_codegen_cuda.py 
b/tests/python/codegen/test_target_codegen_cuda.py
index e96217034f..063ed0469b 100644
--- a/tests/python/codegen/test_target_codegen_cuda.py
+++ b/tests/python/codegen/test_target_codegen_cuda.py
@@ -746,5 +746,36 @@ def test_invalid_reinterpret():
         tvm.compile(func, target="cuda")
 
 
[email protected]_cuda
[email protected]_cuda_compute_version(9)
+def test_cuda_tensormap():
+    # fmt: off
+    @T.prim_func
+    def main(A_ptr: T.handle):
+        A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16)
+
+        A_map: T.handle("tensormap") = T.tvm_stack_alloca("tensormap", 1)
+        T.call_packed("runtime.cuTensorMapInit", A_map, "float32", 2, A.data,
+                      16, 16, 64, 16, 16, 1, 1, 0, 0, 0, 0)
+
+        for blockIdx in T.thread_binding(1, thread="blockIdx.x"):
+            for threadIdx in T.thread_binding(128, thread="threadIdx.x"):
+                if threadIdx == 0:
+                    A[0, 0] = T.reinterpret("float64", A_map)
+    # fmt: on
+
+    mod = tvm.IRModule({"main": main})
+    mod = tvm.compile(mod, target="cuda")
+    assert (
+        """
+extern "C" __global__ void __launch_bounds__(128) main_kernel(float* 
__restrict__ A, const __grid_constant__ CUtensorMap A_map) {
+  if (((int)threadIdx.x) == 0) {
+    A[0] = ((float)(*(double *)(&(A_map))));
+  }
+}""".strip()
+        in mod.mod.imported_modules[0].get_source()
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git 
a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py 
b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
index f620610f39..1858c00e86 100644
--- a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
+++ b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
@@ -239,7 +239,8 @@ def test_inject_async_copy_barrier():
         tvm.testing.assert_allclose(B_nd.numpy(), A_np)
 
 
-expected_cuda_script = r"""__forceinline__ __device__ unsigned int
+expected_cuda_script = r"""#include <cuda.h>
+__forceinline__ __device__ unsigned int
 cast_smem_ptr_to_int(const void* const smem_ptr)
 {
   unsigned int smem_int;
@@ -469,6 +470,7 @@ def 
test_cp_async_in_if_then_else(postproc_if_missing_async_support):
     with tvm.transform.PassContext(config={"tir.use_async_copy": 1}):
         tvm.compile(mod, target="cuda")
     generated_code = postproc_if_missing_async_support()
+    print(generated_code)
     assert generated_code == expected_cuda_script
 
 

Reply via email to