This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b6db2ec89f [Runtime] CutensorMap support (#18097)
b6db2ec89f is described below
commit b6db2ec89f6cf62562a790b3cc98404ef505c0fa
Author: Bohan Hou <[email protected]>
AuthorDate: Mon Jun 30 07:59:52 2025 -0400
[Runtime] CutensorMap support (#18097)
This PR introduces CUtensorMap support in the runtime module. It enables
calling kernels whose arguments are CUtensorMap descriptors: these arguments
are passed as handles (addresses) and associated with arg_extra_tags that
indicate they are tensor maps. The tensor map itself is allocated on the stack
through a runtime API.
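
For illustration, a minimal TVMScript sketch of the intended flow, distilled
from the test added in this PR (the packed-call argument layout is documented
in cuda_device_api.cc):

    import tvm
    from tvm.script import tir as T

    @T.prim_func
    def main(A_ptr: T.handle):
        A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16)
        # 128-byte CUtensorMap descriptor, stack-allocated by the host code
        A_map: T.handle("tensormap") = T.tvm_stack_alloca("tensormap", 1)
        # map, dtype, rank, data ptr, global shape, global strides (bytes),
        # shared box shape, shared strides, then four CUtensorMap enum kinds
        T.call_packed("runtime.cuTensorMapInit", A_map, "float32", 2, A.data,
                      16, 16, 64, 16, 16, 1, 1, 0, 0, 0, 0)
        for bx in T.thread_binding(1, thread="blockIdx.x"):
            for tx in T.thread_binding(128, thread="threadIdx.x"):
                if tx == 0:
                    A[0, 0] = T.reinterpret("float64", A_map)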
---
include/tvm/ir/type.h | 29 +++
include/tvm/script/ir_builder/tir/ir.h | 2 +
python/tvm/ir/type.py | 16 ++
python/tvm/script/ir_builder/tir/ir.py | 2 +
src/ir/type.cc | 12 ++
src/runtime/cuda/cuda_device_api.cc | 206 +++++++++++++++++++++
src/runtime/cuda/cuda_module.cc | 2 +-
src/runtime/file_utils.cc | 13 ++
src/runtime/meta_data.h | 3 +
src/runtime/pack_args.h | 26 ++-
src/script/ir_builder/tir/ir.cc | 1 +
src/script/printer/tir/expr.cc | 30 +--
src/target/build_common.h | 10 +
src/target/llvm/codegen_cpu.cc | 4 +
src/target/llvm/codegen_llvm.cc | 8 +-
src/target/llvm/codegen_llvm.h | 1 +
src/target/source/codegen_c.cc | 16 +-
src/target/source/codegen_cuda.cc | 3 +-
tests/python/codegen/test_target_codegen_cuda.py | 31 ++++
.../test_tir_transform_inject_ptx_async_copy.py | 4 +-
20 files changed, 396 insertions(+), 23 deletions(-)
diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h
index 5ca35449fc..d864766d7f 100644
--- a/include/tvm/ir/type.h
+++ b/include/tvm/ir/type.h
@@ -330,5 +330,34 @@ class FuncType : public Type {
TVM_DEFINE_OBJECT_REF_METHODS(FuncType, Type, FuncTypeNode);
};
+/*!
+ * \brief The type of tensor map.
+ * \sa TensorMapType
+ */
+class TensorMapTypeNode : public TypeNode {
+ public:
+ void VisitAttrs(AttrVisitor* v) { v->Visit("span", &span); }
+
+ bool SEqualReduce(const TensorMapTypeNode* other, SEqualReducer equal) const {
+ return equal(span, other->span);
+ }
+
+ void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(span); }
+
+ static constexpr const char* _type_key = "TensorMapType";
+ TVM_DECLARE_FINAL_OBJECT_INFO(TensorMapTypeNode, TypeNode);
+};
+
+/*!
+ * \brief Managed reference to TensorMapTypeNode.
+ * \sa TensorMapTypeNode
+ */
+class TensorMapType : public Type {
+ public:
+ TVM_DLL TensorMapType(Span span = Span());
+
+ TVM_DEFINE_OBJECT_REF_METHODS_WITHOUT_DEFAULT_CONSTRUCTOR(TensorMapType, Type, TensorMapTypeNode);
+};
+
} // namespace tvm
#endif // TVM_IR_TYPE_H_
diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h
index febdac55d9..30b5bb3382 100644
--- a/include/tvm/script/ir_builder/tir/ir.h
+++ b/include/tvm/script/ir_builder/tir/ir.h
@@ -452,6 +452,8 @@ inline Var Handle(runtime::DataType dtype = runtime::DataType::Void(),
return is_size_var ? tvm::tir::SizeVar("", type_annotation) : tvm::tir::Var("", type_annotation);
}
+inline Var TensormapHandle() { return tvm::tir::Var("", PointerType(TensorMapType())); }
+
#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType) \
inline PrimExpr FuncName(Optional<PrimExpr> expr = std::nullopt, bool is_size_var = false) { \
DataType dtype = DType; \
diff --git a/python/tvm/ir/type.py b/python/tvm/ir/type.py
index 9ec8ef8fbd..d0bf7014e2 100644
--- a/python/tvm/ir/type.py
+++ b/python/tvm/ir/type.py
@@ -107,3 +107,19 @@ class FuncType(Type):
arg_types,
ret_type,
)
+
+
[email protected]_object("TensorMapType")
+class TensorMapType(Type):
+ """TensorMapType used in the low-level TIR.
+
+ Parameters
+ ----------
+ span : tvm.ir.Span
+ The span information.
+ """
+
+ def __init__(self, span=None):
+ self.__init_handle_by_constructor__(
+ _ffi_api.TensorMapType, span # pylint: disable=no-member
+ )
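
For reference, a small Python sketch of constructing the new type (a sketch,
importing directly from tvm.ir.type, since the diff does not update the
tvm/ir/__init__.py exports):

    import tvm
    from tvm.ir.type import TensorMapType

    # The type carries no payload besides an optional span.
    tmap_ty = TensorMapType()
    # A pointer-to-tensormap variable, the same annotation that
    # T.handle("tensormap") produces via the ir_builder change below.
    a_map = tvm.tir.Var("A_map", tvm.ir.PointerType(tmap_ty))
    print(a_map.type_annotation)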
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index c7589f4a19..5864de2cac 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -1566,6 +1566,8 @@ def handle(
res : PrimExpr
The new tir.Var with type handle or casted expression with type handle.
"""
+ if dtype == "tensormap":
+ return _ffi_api.TensormapHandle()  # type: ignore[attr-defined] # pylint: disable=no-member
is_unknown_type = dtype is None
if dtype is None:
dtype = "void"
diff --git a/src/ir/type.cc b/src/ir/type.cc
index 95b65475be..cd7b6a523c 100644
--- a/src/ir/type.cc
+++ b/src/ir/type.cc
@@ -89,4 +89,16 @@
TVM_FFI_REGISTER_GLOBAL("ir.TupleType").set_body_typed([](Array<Type> fields) {
return TupleType(fields);
});
+TVM_FFI_REGISTER_GLOBAL("ir.TensorMapType").set_body_typed([](Span span) {
+ return TensorMapType(span);
+});
+
+TensorMapType::TensorMapType(Span span) {
+ ObjectPtr<TensorMapTypeNode> n = make_object<TensorMapTypeNode>();
+ n->span = std::move(span);
+ data_ = std::move(n);
+}
+
+TVM_REGISTER_NODE_TYPE(TensorMapTypeNode);
+
} // namespace tvm
diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc
index 399312e193..98a83f4ed7 100644
--- a/src/runtime/cuda/cuda_device_api.cc
+++ b/src/runtime/cuda/cuda_device_api.cc
@@ -357,5 +357,211 @@ TVM_DLL int GetCudaDeviceCount() {
TVM_FFI_REGISTER_GLOBAL("runtime.GetCudaDeviceCount").set_body_typed(GetCudaDeviceCount);
+/**
+ * \brief FFI wrapper for cuTensorMapEncodeTiled.
+ *
+ * This function registers a global function `runtime.cuTensorMapEncodeTiled` that can be
+ * called from other parts of the TVM runtime (e.g., Python). It wraps the CUDA Driver API
+ * function `cuTensorMapEncodeTiled`, which initializes a tensor map descriptor (CUtensorMap).
+ *
+ * \param tensor_map (handle): A `void*` pointer to the CUtensorMap object to be initialized.
+ * \param tensor_dtype (DataType): The TVM data type of the tensor.
+ * \param tensor_rank (int): The rank (number of dimensions) of the tensor.
+ * \param tensor_ptr (handle): A `void*` pointer to the start of the tensor in global memory.
+ * \param global_shape (int...): `tensor_rank` integer arguments for the global tensor dimensions.
+ * \param global_strides (int...): `tensor_rank - 1` integer arguments for the global tensor
+ * strides. The stride for the innermost dimension is not provided as it's assumed to be contiguous.
+ * \param shared_shape (int...): `tensor_rank` integer arguments for the shape of the tile (box)
+ * in shared memory.
+ * \param shared_strides (int...): `tensor_rank` integer arguments for the strides of the tile (box)
+ * in shared memory.
+ * \param interleaved_kind (int): An integer corresponding to the CUtensorMapInterleave enum.
+ * \param swizzle_kind (int): An integer corresponding to the CUtensorMapSwizzle enum.
+ * \param l2_promotion_kind (int): An integer corresponding to the CUtensorMapL2promotion enum.
+ * \param oob_fill_kind (int): An integer corresponding to the CUtensorMapFloatOOBfill enum.
+ */
+TVM_FFI_REGISTER_GLOBAL("runtime.cuTensorMapEncodeTiled")
+ .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) {
+ CHECK_GE(args.size(), 4) << "cuTensorMapEncodeTiled expects at least 4 arguments";
+ size_t arg_cnt = 0;
+ CUtensorMap* tensor_map = static_cast<CUtensorMap*>(args[arg_cnt++].cast<void*>());
+ runtime::DataType tensor_dtype = args[arg_cnt++].cast<runtime::DataType>();
+ uint32_t tensor_rank = static_cast<uint32_t>(args[arg_cnt++].cast<int32_t>());
+ void* tensor_ptr = static_cast<void*>(args[arg_cnt++].cast<void*>());
+
+ CHECK_EQ(args.size(), 4 + tensor_rank * 4 + 3)
+ << "cuTensorMapEncodeTiled expects " << 4 + tensor_rank * 4 + 3 << " arguments: "
+ << "tensor_map, tensor_dtype, tensor_rank, tensor_ptr, global_shape(" << tensor_rank
+ << "), global_strides(" << tensor_rank - 1 << "), shared_shape(" << tensor_rank
+ << "), shared_strides(" << tensor_rank << "), interleaved_kind, swizzle_kind"
+ << ", l2_promotion_kind, oob_fill_kind";
+
+ std::vector<cuuint64_t> global_shape(tensor_rank);
+ std::vector<cuuint64_t> global_strides(tensor_rank);
+ std::vector<uint32_t> shared_shape(tensor_rank);
+ std::vector<uint32_t> shared_strides(tensor_rank);
+ for (size_t i = 0; i < tensor_rank; ++i) {
+ global_shape[i] = static_cast<cuuint64_t>(args[arg_cnt++].cast<int64_t>());
+ }
+ for (size_t i = 0; i < tensor_rank - 1; ++i) {
+ global_strides[i] = static_cast<cuuint64_t>(args[arg_cnt++].cast<int64_t>());
+ CHECK_EQ(global_strides[i] % 16, 0) << "global strides must be multiple of 16";
+ }
+ for (size_t i = 0; i < tensor_rank; ++i) {
+ shared_shape[i] = static_cast<uint32_t>(args[arg_cnt++].cast<int32_t>());
+ CHECK_GE(shared_shape[i], 0) << "boxDim must be non-negative";
+ CHECK_LE(shared_shape[i], 256) << "boxDim must be less than or equal to 256";
+ }
+ for (size_t i = 0; i < tensor_rank; ++i) {
+ shared_strides[i] = static_cast<uint32_t>(args[arg_cnt++].cast<int32_t>());
+ }
+ auto interleaved_kind = static_cast<CUtensorMapInterleave>(args[arg_cnt++].cast<int>());
+ auto swizzle_kind = static_cast<CUtensorMapSwizzle>(args[arg_cnt++].cast<int>());
+ auto l2_promotion_kind = static_cast<CUtensorMapL2promotion>(args[arg_cnt++].cast<int>());
+ auto oob_fill_kind = static_cast<CUtensorMapFloatOOBfill>(args[arg_cnt++].cast<int>());
+
+ ICHECK_EQ(tensor_dtype.lanes(), 1)
+ << "Expect tensor_dtype to have lanes=1, but get " << tensor_dtype;
+ CUtensorMapDataType cu_dtype;
+ switch (tensor_dtype.code()) {
+ case DataType::kInt:
+ // int
+ switch (tensor_dtype.bits()) {
+ case 8:
+ cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT8;
+ break;
+ case 32:
+ cu_dtype = CU_TENSOR_MAP_DATA_TYPE_INT32;
+ break;
+ case 64:
+ cu_dtype = CU_TENSOR_MAP_DATA_TYPE_INT64;
+ break;
+ default:
+ LOG(FATAL) << "Unsupported data type " <<
runtime::DLDataTypeToString(tensor_dtype);
+ }
+ break;
+ case DataType::kUInt:
+ // unsigned int
+ switch (tensor_dtype.bits()) {
+ case 8:
+ cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT8;
+ break;
+ case 16:
+ cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT16;
+ break;
+ case 32:
+ cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT32;
+ break;
+ case 64:
+ cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT64;
+ break;
+ default:
+ LOG(FATAL) << "Unsupported data type " <<
runtime::DLDataTypeToString(tensor_dtype);
+ }
+ break;
+ case DataType::kFloat:
+ // float
+ switch (tensor_dtype.bits()) {
+ case 16:
+ cu_dtype = CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
+ break;
+ case 32:
+ cu_dtype = CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
+ break;
+ case 64:
+ cu_dtype = CU_TENSOR_MAP_DATA_TYPE_FLOAT64;
+ break;
+ default:
+ LOG(FATAL) << "Unsupported data type " <<
runtime::DLDataTypeToString(tensor_dtype);
+ }
+ break;
+ case DataType::kBFloat:
+ // bfloat
+ switch (tensor_dtype.bits()) {
+ case 16:
+ cu_dtype = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
+ break;
+ default:
+ LOG(FATAL) << "Unsupported data type " <<
runtime::DLDataTypeToString(tensor_dtype);
+ }
+ break;
+ case DataType::kFloat8_e4m3fn:
+ // NV float8 e4m3
+ cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT8;
+ break;
+ case DataType::kFloat8_e5m2:
+ // NV float8 e5m2
+ cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT8;
+ break;
+ default:
+ LOG(FATAL) << "Unsupported data type " <<
runtime::DLDataTypeToString(tensor_dtype);
+ }
+
+ // sanity checks per cuTensorMapEncodeTiled requirements
+ // see
+ // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7
+ CHECK_EQ((reinterpret_cast<uint64_t>(tensor_ptr) & 0b1111), 0);    // 16-byte alignment
+ CHECK_EQ((reinterpret_cast<uint64_t>(tensor_map) & 0b111111), 0);  // 64-byte alignment
+ CHECK_LE(tensor_rank, 5) << "cuTensorMapEncodeTiled only supports up to 5D tensors";
+
+ if (swizzle_kind == CU_TENSOR_MAP_SWIZZLE_32B) {
+ CHECK_LE(shared_shape[0] * tensor_dtype.bytes(), 32)
+ << "CU_TENSOR_MAP_SWIZZLE_32B implies the bounding box inner dimension will be <= 32.";
+ } else if (swizzle_kind == CU_TENSOR_MAP_SWIZZLE_64B) {
+ CHECK_LE(shared_shape[0] * tensor_dtype.bytes(), 64)
+ << "CU_TENSOR_MAP_SWIZZLE_64B implies the bounding box inner dimension will be <= 64.";
+ } else if (swizzle_kind == CU_TENSOR_MAP_SWIZZLE_128B) {
+ CHECK_LE(shared_shape[0] * tensor_dtype.bytes(), 128)
+ << "CU_TENSOR_MAP_SWIZZLE_128B implies the bounding box inner dimension will be <= "
+ "128.";
+ }
+
+ const cuuint64_t* global_shape_ptr = global_shape.data();
+ const cuuint64_t* global_strides_ptr = global_strides.data();
+ const uint32_t* shared_shape_ptr = shared_shape.data();
+ const uint32_t* shared_strides_ptr = shared_strides.data();
+
+ CUresult res =
+ cuTensorMapEncodeTiled(tensor_map, cu_dtype, tensor_rank, tensor_ptr, global_shape_ptr,
+ global_strides_ptr, shared_shape_ptr, shared_strides_ptr,
+ interleaved_kind, swizzle_kind, l2_promotion_kind, oob_fill_kind);
+ const char* errstr;
+ cuGetErrorString(res, &errstr);
+ if (res != CUDA_SUCCESS) {
+ // get error string
+ const char* error_string = nullptr;
+ cuGetErrorString(res, &error_string);
+ std::cerr << "Error in cuTensorMapEncodeTiled: " << error_string <<
std::endl;
+ std::cout << "cu_dtype: " << cu_dtype << "\n";
+ std::cout << "TMA Desc Addr: " << tensor_map << "\n";
+ std::cout << "TMA Interleave: " << interleaved_kind << "\n";
+ std::cout << "TMA L2Promotion: " << l2_promotion_kind << "\n";
+ std::cout << "TMA OOBFill: " << oob_fill_kind << "\n";
+ std::cout << "SMEM Swizzle: " << swizzle_kind << "\n";
+ std::cout << "tensor rank: " << tensor_rank << "\n";
+ std::cout << "global prob shape: ";
+ for (size_t i = 0; i < tensor_rank; i++) {
+ std::cout << global_shape[i] << " ";
+ }
+ std::cout << "\n";
+ std::cout << "global prob stride: ";
+ for (size_t i = 0; i < tensor_rank; i++) {
+ std::cout << global_strides[i] << " ";
+ }
+ std::cout << "\n";
+ std::cout << "smem box shape: ";
+ for (size_t i = 0; i < tensor_rank; i++) {
+ std::cout << shared_shape[i] << " ";
+ }
+ std::cout << "\n";
+ std::cout << "smem box stride: ";
+ for (size_t i = 0; i < tensor_rank; i++) {
+ std::cout << shared_strides[i] << " ";
+ }
+ std::cout << "\n";
+ CHECK_EQ(res, CUDA_SUCCESS) << "Error in cuTensorMapEncodeTiled: " << errstr;
+ }
+ });
+
} // namespace runtime
} // namespace tvm
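
As a usage illustration, a hedged Python sketch of calling the packed function
registered above, following the documented argument layout (rank-2 float32
tensor, 16x16 global extent, 16x16 shared box, all enum kinds zero). The
ctypes handle conversions and the DLTensor-based device-pointer access are
assumptions for the sketch, not part of this commit:

    import ctypes
    import numpy as np
    import tvm

    encode = tvm.get_global_func("runtime.cuTensorMapEncodeTiled")

    a = tvm.nd.array(np.zeros((16, 16), dtype="float32"), tvm.cuda(0))
    data_ptr = a.handle.contents.data  # device pointer via the ctypes DLTensor

    # 128-byte CUtensorMap storage, over-allocated so a 64-byte aligned
    # address can be handed to the runtime (the checks above require it).
    buf = np.zeros(192, dtype="uint8")
    tmap = ctypes.c_void_p((buf.ctypes.data + 63) & ~63)

    encode(tmap, "float32", 2, ctypes.c_void_p(data_ptr),
           16, 16,      # global_shape: tensor_rank entries
           64,          # global_strides: rank - 1 entries, bytes, multiple of 16
           16, 16,      # shared_shape: box dims, each <= 256
           1, 1,        # shared_strides
           0, 0, 0, 0)  # interleave, swizzle, l2_promotion, oob_fill kinds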
diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc
index a29d303acf..6d69fde5cd 100644
--- a/src/runtime/cuda/cuda_module.cc
+++ b/src/runtime/cuda/cuda_module.cc
@@ -266,7 +266,7 @@ ffi::Function CUDAModuleNode::GetFunction(const String& name,
const FunctionInfo& info = it->second;
CUDAWrappedFunc f;
f.Init(this, sptr_to_self, name, info.arg_types.size(), info.launch_param_tags);
- return PackFuncVoidAddr(f, info.arg_types);
+ return PackFuncVoidAddr(f, info.arg_types, info.arg_extra_tags);
}
Module CUDAModuleCreate(std::string data, std::string fmt,
diff --git a/src/runtime/file_utils.cc b/src/runtime/file_utils.cc
index 2aa377b9f8..513efbd9fb 100644
--- a/src/runtime/file_utils.cc
+++ b/src/runtime/file_utils.cc
@@ -45,6 +45,11 @@ void FunctionInfo::Save(dmlc::JSONWriter* writer) const {
writer->WriteObjectKeyValue("name", name);
writer->WriteObjectKeyValue("arg_types", sarg_types);
writer->WriteObjectKeyValue("launch_param_tags", launch_param_tags);
+ std::vector<int> iarg_extra_tags(arg_extra_tags.size());
+ for (size_t i = 0; i < arg_extra_tags.size(); ++i) {
+ iarg_extra_tags[i] = static_cast<int>(arg_extra_tags[i]);
+ }
+ writer->WriteObjectKeyValue("arg_extra_tags", iarg_extra_tags);
writer->EndObject();
}
@@ -56,6 +61,12 @@ void FunctionInfo::Load(dmlc::JSONReader* reader) {
helper.DeclareOptionalField("launch_param_tags", &launch_param_tags);
helper.DeclareOptionalField("thread_axis_tags",
&launch_param_tags); // for backward
compatibility
+ std::vector<int> iarg_extra_tags;
+ helper.DeclareOptionalField("arg_extra_tags", &iarg_extra_tags);
helper.ReadAllFields(reader);
+ arg_extra_tags.resize(iarg_extra_tags.size());
+ for (size_t i = 0; i < arg_extra_tags.size(); ++i) {
+ arg_extra_tags[i] = static_cast<ArgExtraTags>(iarg_extra_tags[i]);
+ }
arg_types.resize(sarg_types.size());
for (size_t i = 0; i < arg_types.size(); ++i) {
@@ -67,12 +78,14 @@ void FunctionInfo::Save(dmlc::Stream* writer) const {
writer->Write(name);
writer->Write(arg_types);
writer->Write(launch_param_tags);
+ writer->Write(arg_extra_tags);
}
bool FunctionInfo::Load(dmlc::Stream* reader) {
if (!reader->Read(&name)) return false;
if (!reader->Read(&arg_types)) return false;
if (!reader->Read(&launch_param_tags)) return false;
+ if (!reader->Read(&arg_extra_tags)) return false;
return true;
}
diff --git a/src/runtime/meta_data.h b/src/runtime/meta_data.h
index 51120c1f9e..8acefecaad 100644
--- a/src/runtime/meta_data.h
+++ b/src/runtime/meta_data.h
@@ -59,6 +59,9 @@ struct FunctionInfo {
std::vector<DLDataType> arg_types;
std::vector<std::string> launch_param_tags;
+ enum class ArgExtraTags : int { kNone = 0, kTensorMap = 1 };
+ std::vector<ArgExtraTags> arg_extra_tags;
+
void Save(dmlc::JSONWriter* writer) const;
void Load(dmlc::JSONReader* reader);
void Save(dmlc::Stream* writer) const;
diff --git a/src/runtime/pack_args.h b/src/runtime/pack_args.h
index 0068db51d5..8929f90b0f 100644
--- a/src/runtime/pack_args.h
+++ b/src/runtime/pack_args.h
@@ -64,12 +64,15 @@ union ArgUnion64 {
*
* \param f with signiture (ffi::PackedArgs args, ffi::Any* rv, void* void_args)
* \param arg_types The arguments type information.
+ * \param arg_extra_tags extra tags for the arguments
* \tparam F the function type
*
* \return The wrapped packed function.
*/
template <typename F>
-inline ffi::Function PackFuncVoidAddr(F f, const std::vector<DLDataType>& arg_types);
+inline ffi::Function PackFuncVoidAddr(
+ F f, const std::vector<DLDataType>& arg_types,
+ const std::vector<FunctionInfo::ArgExtraTags>& arg_extra_tags = {});
/*!
* \brief Create a packed function that from function only packs buffer arguments.
*
@@ -130,7 +133,8 @@ enum ArgConvertCode {
INT64_TO_UINT32,
FLOAT64_TO_FLOAT32,
FLOAT64_TO_FLOAT64,
- HANDLE_TO_HANDLE
+ HANDLE_TO_HANDLE,
+ HANDLE_TO_TENSORMAP
};
inline ArgConvertCode GetArgConvertCode(DLDataType t) {
@@ -183,6 +187,10 @@ inline ffi::Function PackFuncVoidAddr_(F f, const std::vector<ArgConvertCode>& c
addr[i] = &(holder[i]);
break;
}
+ case HANDLE_TO_TENSORMAP: {
+ addr[i] = raw_args[i].v_ptr;
+ break;
+ }
}
}
f(args, ret, addr);
@@ -222,7 +230,8 @@ inline ffi::Function PackFuncNonBufferArg_(F f, int base,
holder[i].v_float32[0] = static_cast<float>(raw_args[base + i].v_float64);
break;
}
- case HANDLE_TO_HANDLE: {
+ case HANDLE_TO_HANDLE:
+ case HANDLE_TO_TENSORMAP: {
LOG(FATAL) << "not reached";
break;
}
@@ -284,6 +293,7 @@ inline ffi::Function PackFuncPackedArgAligned_(F f, const std::vector<ArgConvert
++ptr;
break;
}
+ case HANDLE_TO_TENSORMAP:
default: {
LOG(FATAL) << "not reached";
break;
@@ -297,10 +307,16 @@ inline ffi::Function PackFuncPackedArgAligned_(F f, const std::vector<ArgConvert
} // namespace detail
template <typename F>
-inline ffi::Function PackFuncVoidAddr(F f, const std::vector<DLDataType>& arg_types) {
+inline ffi::Function PackFuncVoidAddr(
+ F f, const std::vector<DLDataType>& arg_types,
+ const std::vector<FunctionInfo::ArgExtraTags>& arg_extra_tags) {
std::vector<detail::ArgConvertCode> codes(arg_types.size());
for (size_t i = 0; i < arg_types.size(); ++i) {
- codes[i] = detail::GetArgConvertCode(arg_types[i]);
+ if (arg_extra_tags.size() > i && arg_extra_tags[i] == FunctionInfo::ArgExtraTags::kTensorMap) {
+ codes[i] = detail::HANDLE_TO_TENSORMAP;
+ } else {
+ codes[i] = detail::GetArgConvertCode(arg_types[i]);
+ }
}
size_t num_void_args = arg_types.size();
// specialization
diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc
index 79e14feee4..7ef970fa09 100644
--- a/src/script/ir_builder/tir/ir.cc
+++ b/src/script/ir_builder/tir/ir.cc
@@ -830,6 +830,7 @@ TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float4E2M1FN", Float4E2M1FN
TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Boolean").set_body_typed(Boolean);
TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Handle").set_body_typed(Handle);
+TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.TensormapHandle").set_body_typed(TensormapHandle);
TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Void").set_body_typed(Void);
TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.min")
diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc
index 549247449e..d2f02f7908 100644
--- a/src/script/printer/tir/expr.cc
+++ b/src/script/printer/tir/expr.cc
@@ -37,19 +37,23 @@ ExprDoc PrintVarCreation(const tir::Var& var, const ObjectPath& var_p, const IRD
}
if (const auto* ptr_type = type.as<PointerTypeNode>()) {
- const auto* prim_type = ptr_type->element_type.as<PrimTypeNode>();
- ICHECK(prim_type);
- ExprDoc element_type =
- LiteralDoc::DataType(prim_type->dtype, type_p->Attr("element_type")->Attr("dtype"));
- rhs = TIR(d, "handle");
- rhs->source_paths.push_back(var_p->Attr("dtype"));
- if (ptr_type->storage_scope == "") {
- rhs = rhs->Call({element_type}, kwargs_keys, kwargs_values);
- } else {
- rhs = rhs->Call({element_type,
- LiteralDoc::Str(ptr_type->storage_scope, //
- type_p->Attr("storage_scope"))},
- kwargs_keys, kwargs_values);
+ if (const auto* prim_type = ptr_type->element_type.as<PrimTypeNode>()) {
+ ExprDoc element_type =
+ LiteralDoc::DataType(prim_type->dtype, type_p->Attr("element_type")->Attr("dtype"));
+ rhs = TIR(d, "handle");
+ rhs->source_paths.push_back(var_p->Attr("dtype"));
+ if (ptr_type->storage_scope == "") {
+ rhs = rhs->Call({element_type}, kwargs_keys, kwargs_values);
+ } else {
+ rhs = rhs->Call({element_type,
+ LiteralDoc::Str(ptr_type->storage_scope, //
+ type_p->Attr("storage_scope"))},
+ kwargs_keys, kwargs_values);
+ }
+ } else if (ptr_type->element_type->IsInstance<TensorMapTypeNode>()) {
+ rhs = TIR(d, "handle")
+ ->Call({LiteralDoc::Str("tensormap",
type_p->Attr("element_type")->Attr("dtype"))},
+ {}, {});
}
} else {
rhs = TIR(d, DType2Str(var->dtype));
diff --git a/src/target/build_common.h b/src/target/build_common.h
index 70f15d091e..fda7e2e67c 100644
--- a/src/target/build_common.h
+++ b/src/target/build_common.h
@@ -49,6 +49,16 @@ inline std::unordered_map<std::string, runtime::FunctionInfo> ExtractFuncInfo(co
runtime::FunctionInfo info;
for (size_t i = 0; i < f->params.size(); ++i) {
info.arg_types.push_back(f->params[i].dtype());
+ auto is_tensormap = [](const tir::Var& var) -> bool {
+ const auto* type = var->type_annotation.as<PointerTypeNode>();
+ if (type == nullptr) {
+ return false;
+ }
+ return type->element_type.as<TensorMapTypeNode>() != nullptr;
+ };
+ info.arg_extra_tags.push_back(is_tensormap(f->params[i])
+ ? runtime::FunctionInfo::ArgExtraTags::kTensorMap
+ : runtime::FunctionInfo::ArgExtraTags::kNone);
}
if (auto opt = f->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams)) {
for (const auto& tag : opt.value()) {
diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc
index 76e825ab75..b16617e3d6 100644
--- a/src/target/llvm/codegen_cpu.cc
+++ b/src/target/llvm/codegen_cpu.cc
@@ -1037,6 +1037,10 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) {
return builder_->CreateAlloca(t_tvm_ffi_any_, num);
} else if (type == "array") {
return builder_->CreateAlloca(t_tvm_array_, num);
+ } else if (type == "tensormap") {
+ auto* alloca = builder_->CreateAlloca(t_tvm_tensormap_, num);
+ alloca->setAlignment(llvm::Align(64));
+ return alloca;
} else {
LOG(FATAL) << "Unknown stack alloca type " << type;
}
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index e9bcfa97fd..45dafa85b9 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -154,6 +154,8 @@ void CodeGenLLVM::Init(const std::string& module_name, LLVMTarget* llvm_target,
t_int32_ = llvm::Type::getInt32Ty(*ctx);
t_int64_ = llvm::Type::getInt64Ty(*ctx);
t_float64_ = llvm::Type::getDoubleTy(*ctx);
+ // CUTensorMap is a 128 byte struct, so we use a 128 byte array to represent it.
+ t_tvm_tensormap_ = llvm::ArrayType::get(t_char_, 128);
// meta data
md_very_likely_branch_ = md_builder_->createBranchWeights(1 << 20, 1);
md_tbaa_root_ = md_builder_->createTBAARoot("tvm-tbaa");
@@ -620,11 +622,15 @@ llvm::Type* CodeGenLLVM::GetLLVMType(const Type& type) const {
if (primtype->dtype.is_void() || primtype->dtype.code() >= DataType::kCustomBegin) {
return t_void_p_;
}
+ } else if (ptr->element_type->IsInstance<TensorMapTypeNode>()) {
+ return t_tvm_tensormap_->getPointerTo();
}
// TODO(tvm-team) consider put storage scope into the pointer type.
return llvmGetPointerTo(GetLLVMType(ptr->element_type), GetGlobalAddressSpace());
} else if (IsVoidType(type)) {
return t_void_;
+ } else if (type->IsInstance<TensorMapTypeNode>()) {
+ return t_tvm_tensormap_;
} else {
LOG(FATAL) << "Type " << type << " does not have a corresponding LLVM
Type";
}
@@ -2292,7 +2298,7 @@ llvm::DIType* CodeGenLLVM::GetDebugType(const Type& ty_tir) {
return GetDebugType(ty_tir, GetLLVMType(ty_tir));
}
llvm::DIType* CodeGenLLVM::GetDebugType(const Type& ty_tir, llvm::Type* ty_llvm) {
- if (ty_llvm == nullptr || ty_llvm == t_void_) {
+ if (ty_llvm == nullptr || ty_llvm == t_void_ || ty_llvm == t_tvm_tensormap_) {
return nullptr;
} else if (ty_llvm->isPointerTy()) {
diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h
index f7e4e81903..e1667b6375 100644
--- a/src/target/llvm/codegen_llvm.h
+++ b/src/target/llvm/codegen_llvm.h
@@ -540,6 +540,7 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
llvm::Type* t_int32_{nullptr};
llvm::Type* t_int64_{nullptr};
llvm::Type* t_float64_{nullptr};
+ llvm::ArrayType* t_tvm_tensormap_{nullptr};
// meta data
llvm::MDNode* md_very_likely_branch_{nullptr};
llvm::MDNode* md_tbaa_root_{nullptr};
diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc
index 344c0857c4..11f0eaf1ba 100644
--- a/src/target/source/codegen_c.cc
+++ b/src/target/source/codegen_c.cc
@@ -93,10 +93,24 @@ void CodeGenC::PrintFunctionSignature(const String& function_name, const PrimFun
PrintStorageScope(it->second, os);
}
- PrintType(GetType(v), os);
+ auto is_tensormap_ptr = [&]() -> bool {
+ if (auto* ptr = v->type_annotation.as<PointerTypeNode>()) {
+ return ptr->element_type.as<TensorMapTypeNode>();
+ }
+ return false;
+ };
+ if (is_tensormap_ptr()) {
+ os << "const __grid_constant__ CUtensorMap";
+ } else {
+ PrintType(GetType(v), os);
+ }
bool no_alias = func->HasNonzeroAttr(tir::attr::kNoAlias);
bool is_handle = v.dtype().is_handle();
+ auto* ptr = v->type_annotation.as<PointerTypeNode>();
+ if (ptr && ptr->element_type.as<TensorMapTypeNode>()) {
+ is_handle = false;
+ }
if (no_alias && is_handle) {
PrintRestrict(v, os);
}
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index d4e1b785b8..21fbc20f47 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -182,6 +182,8 @@ void CodeGenCUDA::PrintExtraAttrs(const PrimFunc& f, std::ostream& os) {
}
std::string CodeGenCUDA::Finish() {
+ decl_stream << "#include <cuda.h>\n";
+
if (enable_fp16_) {
decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)\n";
decl_stream << "#include <cuda_fp16.h>\n";
@@ -194,7 +196,6 @@ std::string CodeGenCUDA::Finish() {
decl_stream << _cuda_half_t_def;
decl_stream << "#endif\n\n";
- decl_stream << "#include <cuda.h>\n";
decl_stream << _cuda_half_util;
}
diff --git a/tests/python/codegen/test_target_codegen_cuda.py b/tests/python/codegen/test_target_codegen_cuda.py
index e96217034f..063ed0469b 100644
--- a/tests/python/codegen/test_target_codegen_cuda.py
+++ b/tests/python/codegen/test_target_codegen_cuda.py
@@ -746,5 +746,36 @@ def test_invalid_reinterpret():
tvm.compile(func, target="cuda")
[email protected]_cuda
[email protected]_cuda_compute_version(9)
+def test_cuda_tensormap():
+ # fmt: off
+ @T.prim_func
+ def main(A_ptr: T.handle):
+ A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16)
+
+ A_map: T.handle("tensormap") = T.tvm_stack_alloca("tensormap", 1)
+ T.call_packed("runtime.cuTensorMapInit", A_map, "float32", 2, A.data,
+ 16, 16, 64, 16, 16, 1, 1, 0, 0, 0, 0)
+
+ for blockIdx in T.thread_binding(1, thread="blockIdx.x"):
+ for threadIdx in T.thread_binding(128, thread="threadIdx.x"):
+ if threadIdx == 0:
+ A[0, 0] = T.reinterpret("float64", A_map)
+ # fmt: on
+
+ mod = tvm.IRModule({"main": main})
+ mod = tvm.compile(mod, target="cuda")
+ assert (
+ """
+extern "C" __global__ void __launch_bounds__(128) main_kernel(float*
__restrict__ A, const __grid_constant__ CUtensorMap A_map) {
+ if (((int)threadIdx.x) == 0) {
+ A[0] = ((float)(*(double *)(&(A_map))));
+ }
+}""".strip()
+ in mod.mod.imported_modules[0].get_source()
+ )
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
index f620610f39..1858c00e86 100644
--- a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
+++ b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
@@ -239,7 +239,8 @@ def test_inject_async_copy_barrier():
tvm.testing.assert_allclose(B_nd.numpy(), A_np)
-expected_cuda_script = r"""__forceinline__ __device__ unsigned int
+expected_cuda_script = r"""#include <cuda.h>
+__forceinline__ __device__ unsigned int
cast_smem_ptr_to_int(const void* const smem_ptr)
{
unsigned int smem_int;
@@ -469,6 +470,7 @@ def test_cp_async_in_if_then_else(postproc_if_missing_async_support):
with tvm.transform.PassContext(config={"tir.use_async_copy": 1}):
tvm.compile(mod, target="cuda")
generated_code = postproc_if_missing_async_support()
+ print(generated_code)
assert generated_code == expected_cuda_script