This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 775e05064b [DataType] Rename FP8 dtypes to standard names (#17712)
775e05064b is described below
commit 775e05064b52c846390120d9613e7a89bb52aef7
Author: Ruihang Lai <[email protected]>
AuthorDate: Thu Mar 6 17:35:49 2025 -0500
[DataType] Rename FP8 dtypes to standard names (#17712)
This PR renames the FP8 dtypes in TVM according to standards:
* `e4m3_float8` is renamed to `float8_e4m3fn`,
* `e5m2_float8` is renamed to `float8_e5m2`.
This aligns with dtype names in PyTorch and ml_dtypes.
---
include/tvm/runtime/data_type.h | 77 ++++++++++++++--------
include/tvm/script/ir_builder/tir/ir.h | 6 +-
python/tvm/_ffi/runtime_ctypes.py | 42 ++++++------
python/tvm/contrib/tvmjs.py | 8 +--
python/tvm/relax/backend/cuda/cublas.py | 10 +--
python/tvm/relax/backend/rocm/hipblas.py | 8 +--
python/tvm/runtime/ndarray.py | 8 +--
python/tvm/script/ir_builder/tir/ir.py | 66 +++++++++----------
python/tvm/tir/tensor_intrin/cuda.py | 38 +++++------
src/ir/expr.cc | 7 +-
src/relax/backend/contrib/utils.h | 10 +--
src/runtime/contrib/cublas/cublas.cc | 5 +-
src/runtime/ndarray.cc | 2 +-
src/script/ir_builder/tir/ir.cc | 12 ++--
src/support/scalars.h | 4 +-
src/target/llvm/codegen_llvm.cc | 4 +-
src/target/source/codegen_cuda.cc | 16 ++---
src/tir/op/op.cc | 8 +--
src/tir/transforms/dtype_conversion.h | 2 +-
.../python/codegen/test_target_codegen_cuda_fp8.py | 44 ++++++-------
tests/python/ir/test_datatype_nv_fp8.py | 8 ++-
tests/python/ir/test_dtype.py | 9 +--
tests/python/relax/test_codegen_cublas.py | 16 ++---
tests/python/relax/test_op_inspect.py | 6 +-
tests/python/relax/test_op_qdq.py | 12 ++--
..._tir_schedule_tensorize_ldmatrix_mma_numeric.py | 27 ++++----
.../test_tir_transform_fp8_legalize.py | 16 ++---
.../python/tvmscript/test_tvmscript_printer_tir.py | 11 ++--
28 files changed, 255 insertions(+), 227 deletions(-)
diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h
index 76e5e3833f..65fd0c98fd 100644
--- a/include/tvm/runtime/data_type.h
+++ b/include/tvm/runtime/data_type.h
@@ -56,9 +56,9 @@ class DataType {
kFloat = kDLFloat,
kHandle = TVMArgTypeCode::kTVMOpaqueHandle,
kBFloat = kDLBfloat,
- kE4M3Float = 6U,
- kE5M2Float = 7U,
- kFloat4E2M1Fn = 8U,
+ kFloat8_e4m3fn = 6U,
+ kFloat8_e5m2 = 7U,
+ kFloat4_e2m1fn = 8U,
kCustomBegin = 129
};
/*! \brief default constructor */
@@ -85,10 +85,10 @@ class DataType {
if (code == kBFloat) {
ICHECK_EQ(bits, 16);
}
- if (code == kE4M3Float || code == kE5M2Float) {
+ if (code == kFloat8_e4m3fn || code == kFloat8_e5m2) {
ICHECK_EQ(bits, 8);
}
- if (code == kFloat4E2M1Fn) {
+ if (code == kFloat4_e2m1fn) {
ICHECK_EQ(bits, 4);
}
}
@@ -126,15 +126,15 @@ class DataType {
bool is_float() const { return code() == DataType::kFloat; }
/*! \return whether type is a float8 type. */
bool is_float8() const {
- return (code() == DataType::kFloat || code() == DataType::kE4M3Float ||
- code() == DataType::kE5M2Float) &&
+ return (code() == DataType::kFloat || code() == DataType::kFloat8_e4m3fn ||
+ code() == DataType::kFloat8_e5m2) &&
bits() == 8;
}
/*! \return whether type is a float4 type. */
- bool is_float4() const { return code() == DataType::kFloat4E2M1Fn && bits()
== 4; }
- bool is_e4m3_float8() const { return (code() == DataType::kE4M3Float &&
bits() == 8); }
- bool is_e5m2_float8() const { return (code() == DataType::kE5M2Float &&
bits() == 8); }
- bool is_float4_e2m1fn() const { return (code() == DataType::kFloat4E2M1Fn &&
bits() == 4); }
+ bool is_float4() const { return code() == DataType::kFloat4_e2m1fn && bits()
== 4; }
+ bool is_float8_e4m3fn() const { return (code() == DataType::kFloat8_e4m3fn
&& bits() == 8); }
+ bool is_float8_e5m2() const { return (code() == DataType::kFloat8_e5m2 &&
bits() == 8); }
+ bool is_float4_e2m1fn() const { return (code() == DataType::kFloat4_e2m1fn
&& bits() == 4); }
/*! \return whether type is a float16 type. */
bool is_float16() const { return is_float() && bits() == 16; }
/*! \return whether type is a bfloat16 type. */
@@ -252,19 +252,19 @@ class DataType {
* \param lanes The number of lanes
* \return The constructed data type.
*/
- static DataType NVFloat8E4M3(int lanes = 1) { return DataType(kE4M3Float, 8,
lanes); }
+ static DataType NVFloat8E4M3(int lanes = 1) { return
DataType(kFloat8_e4m3fn, 8, lanes); }
/*!
* \brief Construct NV float8 e5m2 datatype.
* \param lanes The number of lanes
* \return The constructed data type.
*/
- static DataType NVFloat8E5M2(int lanes = 1) { return DataType(kE5M2Float, 8,
lanes); }
+ static DataType NVFloat8E5M2(int lanes = 1) { return DataType(kFloat8_e5m2,
8, lanes); }
/*!
* \brief Construct NV float4_e2m1fn datatype.
* \param lanes The number of lanes
* \return The constructed data type.
*/
- static DataType NVFloat4E2M1FN(int lanes = 1) { return
DataType(kFloat4E2M1Fn, 4, lanes); }
+ static DataType NVFloat4E2M1FN(int lanes = 1) { return
DataType(kFloat4_e2m1fn, 4, lanes); }
/*!
* \brief Construct a bool type.
* \param lanes The number of lanes.
@@ -393,11 +393,11 @@ inline const char* DLDataTypeCode2Str(DLDataTypeCode
type_code) {
return "handle";
case kDLBfloat:
return "bfloat";
- case DataType::kE4M3Float:
- return "e4m3_float";
- case DataType::kE5M2Float:
- return "e5m2_float";
- case DataType::kFloat4E2M1Fn:
+ case DataType::kFloat8_e4m3fn:
+ return "float8_e4m3fn";
+ case DataType::kFloat8_e5m2:
+ return "float8_e5m2";
+ case DataType::kFloat4_e2m1fn:
return "float4_e2m1fn";
default:
LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code);
@@ -420,7 +420,10 @@ inline std::ostream& operator<<(std::ostream& os,
DLDataType t) { // NOLINT(*)
}
if (t.code == kTVMOpaqueHandle) return os;
int16_t lanes = static_cast<int16_t>(t.lanes);
- os << static_cast<int>(t.bits);
+ if (t.code != DataType::kFloat8_e4m3fn && t.code != DataType::kFloat8_e5m2 &&
+ t.code != DataType::kFloat4_e2m1fn) {
+ os << static_cast<int>(t.bits);
+ }
if (lanes > 1) {
os << 'x' << lanes;
} else if (lanes < -1) {
@@ -458,7 +461,7 @@ inline DLDataType String2DLDataType(std::string s) {
scan = s.c_str() + 4;
} else if (s.substr(0, 13) == "float4_e2m1fn") {
// Avoid being treated as "float"
- t.code = DataType::kFloat4E2M1Fn;
+ t.code = DataType::kFloat4_e2m1fn;
t.bits = 4;
scan = s.c_str() + 13;
char* endpt = nullptr;
@@ -468,6 +471,30 @@ inline DLDataType String2DLDataType(std::string s) {
}
ICHECK(scan == s.c_str() + s.length()) << "unknown type " << s;
return t;
+ } else if (s.substr(0, 13) == "float8_e4m3fn") {
+ // Avoid being treated as "float"
+ t.code = DataType::kFloat8_e4m3fn;
+ t.bits = 8;
+ scan = s.c_str() + 13;
+ char* endpt = nullptr;
+ if (*scan == 'x') {
+ t.lanes = static_cast<uint16_t>(strtoul(scan + 1, &endpt, 10));
+ scan = endpt;
+ }
+ ICHECK(scan == s.c_str() + s.length()) << "unknown type " << s;
+ return t;
+ } else if (s.substr(0, 11) == "float8_e5m2") {
+ // Avoid being treated as "float"
+ t.code = DataType::kFloat8_e5m2;
+ t.bits = 8;
+ scan = s.c_str() + 11;
+ char* endpt = nullptr;
+ if (*scan == 'x') {
+ t.lanes = static_cast<uint16_t>(strtoul(scan + 1, &endpt, 10));
+ scan = endpt;
+ }
+ ICHECK(scan == s.c_str() + s.length()) << "unknown type " << s;
+ return t;
} else if (s.substr(0, 5) == "float") {
t.code = kDLFloat;
scan = s.c_str() + 5;
@@ -484,14 +511,6 @@ inline DLDataType String2DLDataType(std::string s) {
t.code = DataType::kBFloat;
t.bits = 16;
scan = s.c_str() + 6;
- } else if (s.substr(0, 10) == "e4m3_float") {
- t.code = DataType::kE4M3Float;
- t.bits = 8;
- scan = s.c_str() + 10;
- } else if (s.substr(0, 10) == "e5m2_float") {
- t.code = DataType::kE5M2Float;
- t.bits = 8;
- scan = s.c_str() + 10;
} else if (s.substr(0, 6) == "custom") {
t.code = ParseCustomDatatype(s, &scan);
} else {
diff --git a/include/tvm/script/ir_builder/tir/ir.h
b/include/tvm/script/ir_builder/tir/ir.h
index e78e0d51fd..e60a3859ac 100644
--- a/include/tvm/script/ir_builder/tir/ir.h
+++ b/include/tvm/script/ir_builder/tir/ir.h
@@ -502,10 +502,10 @@ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Int,
DataType::Int);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x32, FDType(32)); \
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x64, FDType(64));
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(E4M3Float8,
DataType::NVFloat8E4M3);
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(E5M2Float8,
DataType::NVFloat8E5M2);
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E4M3FN,
DataType::NVFloat8E4M3);
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E5M2,
DataType::NVFloat8E5M2);
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float4E2M1fn,
DataType::NVFloat4E2M1FN);
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float4E2M1FN,
DataType::NVFloat4E2M1FN);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Boolean, DataType::Bool());
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void());
diff --git a/python/tvm/_ffi/runtime_ctypes.py
b/python/tvm/_ffi/runtime_ctypes.py
index 3f4ceadd1d..317bd6bead 100644
--- a/python/tvm/_ffi/runtime_ctypes.py
+++ b/python/tvm/_ffi/runtime_ctypes.py
@@ -66,9 +66,9 @@ class DataTypeCode(object):
FLOAT = 2
HANDLE = 3
BFLOAT = 4
- E4M3Float = 6
- E5M2Float = 7
- FLOAT4E2M1FN = 8
+ Float8E4M3FN = 6
+ Float8E5M2 = 7
+ Float4E2M1FN = 8
class DataType(ctypes.Structure):
@@ -81,9 +81,9 @@ class DataType(ctypes.Structure):
DataTypeCode.FLOAT: "float",
DataTypeCode.HANDLE: "handle",
DataTypeCode.BFLOAT: "bfloat",
- DataTypeCode.E4M3Float: "e4m3_float",
- DataTypeCode.E5M2Float: "e5m2_float",
- DataTypeCode.FLOAT4E2M1FN: "float4_e2m1fn",
+ DataTypeCode.Float8E4M3FN: "float8_e4m3fn",
+ DataTypeCode.Float8E5M2: "float8_e5m2",
+ DataTypeCode.Float4E2M1FN: "float4_e2m1fn",
}
NUMPY2STR = {
np.dtype(np.bool_): "bool",
@@ -112,9 +112,9 @@ class DataType(ctypes.Structure):
"uint16": {"type_code": DataTypeCode.UINT, "bits": 16, "lanes": 1},
"uint32": {"type_code": DataTypeCode.UINT, "bits": 32, "lanes": 1},
"uint64": {"type_code": DataTypeCode.UINT, "bits": 64, "lanes": 1},
- "e4m3_float8": {"type_code": DataTypeCode.E4M3Float, "bits": 8,
"lanes": 1},
- "e5m2_float8": {"type_code": DataTypeCode.E5M2Float, "bits": 8,
"lanes": 1},
- "float4_e2m1fn": {"type_code": DataTypeCode.FLOAT4E2M1FN, "bits": 4,
"lanes": 1},
+ "float8_e4m3fn": {"type_code": DataTypeCode.Float8E4M3FN, "bits": 8,
"lanes": 1},
+ "float8_e5m2": {"type_code": DataTypeCode.Float8E5M2, "bits": 8,
"lanes": 1},
+ "float4_e2m1fn": {"type_code": DataTypeCode.Float4E2M1FN, "bits": 4,
"lanes": 1},
"float16": {"type_code": DataTypeCode.FLOAT, "bits": 16, "lanes": 1},
"float32": {"type_code": DataTypeCode.FLOAT, "bits": 32, "lanes": 1},
"float64": {"type_code": DataTypeCode.FLOAT, "bits": 64, "lanes": 1},
@@ -157,9 +157,17 @@ class DataType(ctypes.Structure):
head = head[4:]
elif head.startswith("float4_e2m1fn"):
# Avoid being treated as "float"
- self.type_code = DataTypeCode.FLOAT4E2M1FN
+ self.type_code = DataTypeCode.Float4E2M1FN
bits = 4
head = ""
+ elif head.startswith("float8_e4m3fn"):
+ self.type_code = DataTypeCode.Float8E4M3FN
+ bits = 8
+ head = ""
+ elif head.startswith("float8_e5m2"):
+ self.type_code = DataTypeCode.Float8E5M2
+ bits = 8
+ head = ""
elif head.startswith("float"):
self.type_code = DataTypeCode.FLOAT
head = head[5:]
@@ -170,12 +178,6 @@ class DataType(ctypes.Structure):
elif head.startswith("bfloat"):
self.type_code = DataTypeCode.BFLOAT
head = head[6:]
- elif head.startswith("e4m3_float"):
- self.type_code = DataTypeCode.E4M3Float
- head = head[10:]
- elif head.startswith("e5m2_float"):
- self.type_code = DataTypeCode.E5M2Float
- head = head[10:]
elif head.startswith("custom"):
# pylint: disable=import-outside-toplevel
import tvm.runtime._ffi_api
@@ -204,7 +206,9 @@ class DataType(ctypes.Structure):
type_name = "custom[%s]" %
tvm.runtime._ffi_api._datatype_get_type_name(self.type_code)
if self.type_code in [
- DataTypeCode.FLOAT4E2M1FN,
+ DataTypeCode.Float8E4M3FN,
+ DataTypeCode.Float8E5M2,
+ DataTypeCode.Float4E2M1FN,
]:
x = type_name
else:
@@ -243,8 +247,8 @@ class DataType(ctypes.Structure):
if ml_dtypes is not None:
DataType.NUMPY2STR[np.dtype(ml_dtypes.bfloat16)] = "bfloat16"
- DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e4m3fn)] = "e4m3_float8"
- DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e5m2)] = "e5m2_float8"
+ DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e4m3fn)] = "float8_e4m3fn"
+ DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e5m2)] = "float8_e5m2"
DataType.NUMPY2STR[np.dtype(ml_dtypes.float4_e2m1fn)] = "float4_e2m1fn"
RPC_SESS_MASK = 128
diff --git a/python/tvm/contrib/tvmjs.py b/python/tvm/contrib/tvmjs.py
index 9bff724df7..5a5a3a1c80 100644
--- a/python/tvm/contrib/tvmjs.py
+++ b/python/tvm/contrib/tvmjs.py
@@ -369,19 +369,19 @@ def load_ndarray_cache(cachepath: str, device:
tvm.runtime.Device):
arr = tvm.nd.empty(shape, dtype, device=device)
assert offset + nbytes <= len(raw_data)
buffer_source = raw_data[offset : offset + nbytes]
- if dtype == "e4m3_float8":
+ if dtype == "float8_e4m3fn":
if ml_dtypes is not None:
dtype = ml_dtypes.float8_e4m3fn
else:
raise RuntimeError(
- "ml_dtypes is not installed, cannot convert
e4m3_float8 array to numpy."
+ "ml_dtypes is not installed, cannot convert
float8_e4m3fn array to numpy."
)
- if dtype == "e5m2_float8":
+ if dtype == "float8_e5m2":
if ml_dtypes is not None:
dtype = ml_dtypes.float8_e5m2
else:
raise RuntimeError(
- "ml_dtypes is not installed, cannot convert
e5m2_float8 array to numpy."
+ "ml_dtypes is not installed, cannot convert
float8_e5m2 array to numpy."
)
if encode_format == "f32-to-bf16" and dtype == "float32":
data = np.frombuffer(buffer_source,
dtype="uint16").reshape(shape)
diff --git a/python/tvm/relax/backend/cuda/cublas.py
b/python/tvm/relax/backend/cuda/cublas.py
index 287b18b440..6828381e68 100644
--- a/python/tvm/relax/backend/cuda/cublas.py
+++ b/python/tvm/relax/backend/cuda/cublas.py
@@ -27,18 +27,18 @@ from tvm.relax.transform import PatternCheckContext
from ..pattern_registry import get_patterns_with_prefix, register_patterns
from ..patterns import (
- make_matmul_pattern,
make_matmul_dequantize_pattern,
make_matmul_multiply_pattern,
+ make_matmul_pattern,
)
from ..utils import has_leaking_intermediate_variables
def _is_supported_dtype(lhs_dtype, rhs_dtype, out_dtype):
"""Check if dtypes in the given workload are supported by cuBLAS BYOC."""
- if lhs_dtype == "e4m3_float8" and rhs_dtype == "e4m3_float8":
- # The output cannot be 'e5m2_float8' if inputs are 'e4m3_float8'
- return out_dtype != "e5m2_float8"
+ if lhs_dtype == "float8_e4m3fn" and rhs_dtype == "float8_e4m3fn":
+ # The output cannot be 'float8_e5m2' if inputs are 'float8_e4m3fn'
+ return out_dtype != "float8_e5m2"
return (
(lhs_dtype == "float16" and rhs_dtype == "float16")
or (lhs_dtype == "float32" and rhs_dtype == "float32")
@@ -83,7 +83,7 @@ def _check_matmul(context: PatternCheckContext) -> bool:
if not isinstance(rhs_shape[-1], (tvm.tir.expr.IntImm, int)) or
rhs_shape[-1] % 4 != 0:
# Rows number must be multiples of 4 for IGEMM
return False
- elif lhs_dtype == "e4m3_float8" and rhs_dtype == "e4m3_float8":
+ elif lhs_dtype == "float8_e4m3fn" and rhs_dtype == "float8_e4m3fn":
matmul_rhs_var = matmul_call.args[1]
rhs_transposed = False
if matmul_rhs_var in context.matched_bindings:
diff --git a/python/tvm/relax/backend/rocm/hipblas.py
b/python/tvm/relax/backend/rocm/hipblas.py
index c0accc1473..63c72b660d 100644
--- a/python/tvm/relax/backend/rocm/hipblas.py
+++ b/python/tvm/relax/backend/rocm/hipblas.py
@@ -30,9 +30,9 @@ from ..utils import has_leaking_intermediate_variables
def _is_supported_dtype(lhs_dtype, rhs_dtype, out_dtype): # pylint:
disable=unused-argument
"""Check if dtypes in the given workload are supported by hipblas BYOC."""
- if lhs_dtype == "e4m3_float8" and rhs_dtype == "e4m3_float8":
- # The output cannot be 'e5m2_float8' if inputs are 'e4m3_float8'
- # return out_dtype != "e5m2_float8"
+ if lhs_dtype == "float8_e4m3fn" and rhs_dtype == "float8_e4m3fn":
+ # The output cannot be 'float8_e5m2' if inputs are 'float8_e4m3fn'
+ # return out_dtype != "float8_e5m2"
return False
return (lhs_dtype == "float16" and rhs_dtype == "float16") or (
lhs_dtype == "int8" and rhs_dtype == "int8"
@@ -61,7 +61,7 @@ def _check_matmul(context: PatternCheckContext) -> bool:
if lhs_dtype == "int8" and rhs_dtype == "int8":
return False
- elif lhs_dtype == "e4m3_float8" and rhs_dtype == "e4m3_float8":
+ elif lhs_dtype == "float8_e4m3fn" and rhs_dtype == "float8_e4m3fn":
return False
lhs_batches = reduce(operator.mul, lhs_shape[:-2], 1)
diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py
index 47fcccf52b..d55334a154 100644
--- a/python/tvm/runtime/ndarray.py
+++ b/python/tvm/runtime/ndarray.py
@@ -249,19 +249,19 @@ class NDArray(NDArrayBase):
dtype = "int8"
if dtype == "bfloat16":
dtype = "uint16"
- if dtype == "e4m3_float8":
+ if dtype == "float8_e4m3fn":
if ml_dtypes is not None:
dtype = ml_dtypes.float8_e4m3fn
else:
raise RuntimeError(
- "ml_dtypes is not installed, cannot convert e4m3_float8
array to numpy."
+ "ml_dtypes is not installed, cannot convert float8_e4m3fn
array to numpy."
)
- if dtype == "e5m2_float8":
+ if dtype == "float8_e5m2":
if ml_dtypes is not None:
dtype = ml_dtypes.float8_e5m2
else:
raise RuntimeError(
- "ml_dtypes is not installed, cannot convert e5m2_float8
array to numpy."
+ "ml_dtypes is not installed, cannot convert float8_e5m2
array to numpy."
)
if dtype == "float4_e2m1fn":
if ml_dtypes is not None:
diff --git a/python/tvm/script/ir_builder/tir/ir.py
b/python/tvm/script/ir_builder/tir/ir.py
index c35df7a093..2fce022da3 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -1442,27 +1442,27 @@ float16x64 = func_gen(("Float16x64"))
float32x64 = func_gen(("Float32x64"))
float64x64 = func_gen(("Float64x64"))
-e4m3_float8 = func_gen(("E4M3Float8"))
-e4m3_float8x4 = func_gen(("E4M3Float8x4"))
-e4m3_float8x8 = func_gen(("E4M3Float8x8"))
-e4m3_float8x16 = func_gen(("E4M3Float8x16"))
-e4m3_float8x32 = func_gen(("E4M3Float8x32"))
-e4m3_float8x64 = func_gen(("E4M3Float8x64"))
-
-e5m2_float8 = func_gen(("E5M2Float8"))
-e5m2_float8x4 = func_gen(("E5M2Float8x4"))
-e5m2_float8x8 = func_gen(("E5M2Float8x8"))
-e5m2_float8x16 = func_gen(("E5M2Float8x16"))
-e5m2_float8x32 = func_gen(("E5M2Float8x32"))
-e5m2_float8x64 = func_gen(("E5M2Float8x64"))
-
-float4_e2m1fn = func_gen(("Float4E2M1fn"))
-float4_e2m1fnx2 = func_gen(("Float4E2M1fnx2"))
-float4_e2m1fnx4 = func_gen(("Float4E2M1fnx4"))
-float4_e2m1fnx8 = func_gen(("Float4E2M1fnx8"))
-float4_e2m1fnx16 = func_gen(("Float4E2M1fnx16"))
-float4_e2m1fnx32 = func_gen(("Float4E2M1fnx32"))
-float4_e2m1fnx64 = func_gen(("Float4E2M1fnx64"))
+float8_e4m3fn = func_gen(("Float8E4M3FN"))
+float8_e4m3fnx4 = func_gen(("Float8E4M3FNx4"))
+float8_e4m3fnx8 = func_gen(("Float8E4M3FNx8"))
+float8_e4m3fnx16 = func_gen(("Float8E4M3FNx16"))
+float8_e4m3fnx32 = func_gen(("Float8E4M3FNx32"))
+float8_e4m3fnx64 = func_gen(("Float8E4M3FNx64"))
+
+float8_e5m2 = func_gen(("Float8E5M2"))
+float8_e5m2x4 = func_gen(("Float8E5M2x4"))
+float8_e5m2x8 = func_gen(("Float8E5M2x8"))
+float8_e5m2x16 = func_gen(("Float8E5M2x16"))
+float8_e5m2x32 = func_gen(("Float8E5M2x32"))
+float8_e5m2x64 = func_gen(("Float8E5M2x64"))
+
+float4_e2m1fn = func_gen(("Float4E2M1FN"))
+float4_e2m1fnx2 = func_gen(("Float4E2M1FNx2"))
+float4_e2m1fnx4 = func_gen(("Float4E2M1FNx4"))
+float4_e2m1fnx8 = func_gen(("Float4E2M1FNx8"))
+float4_e2m1fnx16 = func_gen(("Float4E2M1FNx16"))
+float4_e2m1fnx32 = func_gen(("Float4E2M1FNx32"))
+float4_e2m1fnx64 = func_gen(("Float4E2M1FNx64"))
# pylint: enable=invalid-name
@@ -2011,39 +2011,39 @@ __all__ = [
"uint16x64",
"uint32x64",
"uint64x64",
- "e4m3_float8",
- "e5m2_float8",
+ "float8_e4m3fn",
+ "float8_e5m2",
"float4_e2m1fn",
"float16",
"float32",
"float64",
"float4_e2m1fnx2",
- "e4m3_float8x4",
- "e5m2_float8x4",
+ "float8_e4m3fnx4",
+ "float8_e5m2x4",
"float4_e2m1fnx4",
"float16x4",
"float32x4",
"float64x4",
- "e4m3_float8x8",
- "e5m2_float8x8",
+ "float8_e4m3fnx8",
+ "float8_e5m2x8",
"float4_e2m1fnx8",
"float16x8",
"float32x8",
"float64x8",
- "e4m3_float8x16",
- "e5m2_float8x16",
+ "float8_e4m3fnx16",
+ "float8_e5m2x16",
"float4_e2m1fnx16",
"float16x16",
"float32x16",
"float64x16",
- "e4m3_float8x32",
- "e5m2_float8x32",
+ "float8_e4m3fnx32",
+ "float8_e5m2x32",
"float4_e2m1fnx32",
"float16x32",
"float32x32",
"float64x32",
- "e4m3_float8x64",
- "e5m2_float8x64",
+ "float8_e4m3fnx64",
+ "float8_e5m2x64",
"float4_e2m1fnx64",
"float16x64",
"float32x64",
diff --git a/python/tvm/tir/tensor_intrin/cuda.py
b/python/tvm/tir/tensor_intrin/cuda.py
index e1ff18bc8f..57b1c3b873 100644
--- a/python/tvm/tir/tensor_intrin/cuda.py
+++ b/python/tvm/tir/tensor_intrin/cuda.py
@@ -16,13 +16,13 @@
# under the License.
# pylint: disable=invalid-name,missing-function-docstring,unused-variable
"""Intrinsics for tensorization on NVIDIA GPU."""
-from typing import Dict, Optional, Tuple, Literal
+from typing import Dict, Literal, Optional, Tuple
from tvm._ffi import register_func
from tvm.runtime import convert
from tvm.script import tir as T
-from tvm.tir.function import PrimFunc
from tvm.tir import Cast, IntImm, TensorIntrin
+from tvm.tir.function import PrimFunc
def shared_16x16_to_ldmatrix_32x8_layout(i, j):
@@ -123,7 +123,7 @@ def get_ldmatrix_intrin(
matrix_name == "B" or not transposed
), "Now only B matrix can be transposed for int8 matmul"
assert k_dim == 32 and (
- dtype == "int8" or dtype == "e4m3_float8" or dtype == "e5m2_float8"
+ dtype == "int8" or dtype == "float8_e4m3fn" or dtype ==
"float8_e5m2"
), "Only k_dim == 16 (float16) or k_dim == 32 (int8) supported for now"
if matrix_name == "B" and not transposed:
@@ -261,25 +261,25 @@ LDMATRIX_i8_B_TRANS_INTRIN = "mma_ldmatrix_i8_b_trans"
TensorIntrin.register(LDMATRIX_i8_B_TRANS_INTRIN, *get_ldmatrix_intrin(32,
"int8", "B", True))
LDMATRIX_e4m3_A_INTRIN = "mma_ldmatrix_e4m3_a"
-TensorIntrin.register(LDMATRIX_e4m3_A_INTRIN, *get_ldmatrix_intrin(32,
"e4m3_float8", "A", False))
+TensorIntrin.register(LDMATRIX_e4m3_A_INTRIN, *get_ldmatrix_intrin(32,
"float8_e4m3fn", "A", False))
LDMATRIX_e4m3_B_INTRIN = "mma_ldmatrix_e4m3_b"
-TensorIntrin.register(LDMATRIX_e4m3_B_INTRIN, *get_ldmatrix_intrin(32,
"e4m3_float8", "B", False))
+TensorIntrin.register(LDMATRIX_e4m3_B_INTRIN, *get_ldmatrix_intrin(32,
"float8_e4m3fn", "B", False))
LDMATRIX_e4m3_B_TRANS_INTRIN = "mma_ldmatrix_e4m3_b_trans"
TensorIntrin.register(
- LDMATRIX_e4m3_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, "e4m3_float8", "B",
True)
+ LDMATRIX_e4m3_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, "float8_e4m3fn",
"B", True)
)
LDMATRIX_e5m2_A_INTRIN = "mma_ldmatrix_e5m2_a"
-TensorIntrin.register(LDMATRIX_e5m2_A_INTRIN, *get_ldmatrix_intrin(32,
"e5m2_float8", "A", False))
+TensorIntrin.register(LDMATRIX_e5m2_A_INTRIN, *get_ldmatrix_intrin(32,
"float8_e5m2", "A", False))
LDMATRIX_e5m2_B_INTRIN = "mma_ldmatrix_e5m2_b"
-TensorIntrin.register(LDMATRIX_e5m2_B_INTRIN, *get_ldmatrix_intrin(32,
"e5m2_float8", "B", False))
+TensorIntrin.register(LDMATRIX_e5m2_B_INTRIN, *get_ldmatrix_intrin(32,
"float8_e5m2", "B", False))
LDMATRIX_e5m2_B_TRANS_INTRIN = "mma_ldmatrix_e5m2_b_trans"
TensorIntrin.register(
- LDMATRIX_e5m2_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, "e5m2_float8", "B",
True)
+ LDMATRIX_e5m2_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, "float8_e5m2", "B",
True)
)
@@ -315,8 +315,8 @@ def get_mma_intrin(
"float32": "fp32",
"int8": "int8",
"int32": "int32",
- "e4m3_float8": "e4m3",
- "e5m2_float8": "e5m2",
+ "float8_e4m3fn": "e4m3",
+ "float8_e5m2": "e5m2",
}
a_dtype_abbrv = dtype_abbrv[a_dtype]
b_dtype_abbrv = dtype_abbrv[b_dtype]
@@ -522,25 +522,25 @@ TensorIntrin.register(
MMA_e5m2e5m2f32_INTRIN = "mma_e5m2e5m2f32"
TensorIntrin.register(
MMA_e5m2e5m2f32_INTRIN,
- *get_mma_intrin(32, "e5m2_float8", "e5m2_float8", "float32", False, False),
+ *get_mma_intrin(32, "float8_e5m2", "float8_e5m2", "float32", False, False),
)
MMA_e5m2e5m2f32_TRANS_B_INTRIN = "mma_e5m2e5m2f32_trans_b"
TensorIntrin.register(
MMA_e5m2e5m2f32_TRANS_B_INTRIN,
- *get_mma_intrin(32, "e5m2_float8", "e5m2_float8", "float32", False, True),
+ *get_mma_intrin(32, "float8_e5m2", "float8_e5m2", "float32", False, True),
)
MMA_e4m3e4m3f32_INTRIN = "mma_e4m3e4m3f32"
TensorIntrin.register(
MMA_e4m3e4m3f32_INTRIN,
- *get_mma_intrin(32, "e4m3_float8", "e4m3_float8", "float32", False, False),
+ *get_mma_intrin(32, "float8_e4m3fn", "float8_e4m3fn", "float32", False,
False),
)
MMA_e4m3e4m3f32_TRANS_B_INTRIN = "mma_e4m3e4m3f32_trans_b"
TensorIntrin.register(
MMA_e4m3e4m3f32_TRANS_B_INTRIN,
- *get_mma_intrin(32, "e4m3_float8", "e4m3_float8", "float32", False, True),
+ *get_mma_intrin(32, "float8_e4m3fn", "float8_e4m3fn", "float32", False,
True),
)
@@ -705,7 +705,7 @@ TensorIntrin.register(
def get_mma_intrin_group(
load_scope: Literal["shared", "shared.dyn"],
store_scope: Literal["global", "shared", "shared.dyn"],
- in_dtype: Literal["float16", "int8", "e4m3_float8", "e5m2_float8"],
+ in_dtype: Literal["float16", "int8", "float8_e4m3fn", "float8_e5m2"],
out_dtype: Literal["float16", "float32", "int32"],
trans_a: bool,
trans_b: bool,
@@ -752,7 +752,7 @@ def get_mma_intrin_group(
"""
assert load_scope in ["shared", "shared.dyn"]
assert store_scope in ["global", "shared", "shared.dyn"]
- assert in_dtype in ["float16", "int8", "e4m3_float8", "e5m2_float8"]
+ assert in_dtype in ["float16", "int8", "float8_e4m3fn", "float8_e5m2"]
assert out_dtype in ["float16", "float32", "int32"]
shape = "16x16"
@@ -761,8 +761,8 @@ def get_mma_intrin_group(
"float16": "f16",
"float32": "f32",
"int8": "i8",
- "e4m3_float8": "e4m3",
- "e5m2_float8": "e5m2",
+ "float8_e4m3fn": "e4m3",
+ "float8_e5m2": "e5m2",
"int32": "i32",
}
a_dtype = dtype_mapping[in_dtype]
diff --git a/src/ir/expr.cc b/src/ir/expr.cc
index 766abf3483..8f188e95f0 100644
--- a/src/ir/expr.cc
+++ b/src/ir/expr.cc
@@ -132,15 +132,16 @@ FloatImm::FloatImm(DataType dtype, double value, Span
span) {
ICHECK_LE(value, support::kMaxBFloat16)
<< "ValueError: Literal value " << value << " exceeds maximum of "
<< dtype;
} else if (dtype.is_float8()) {
- double bound = (dtype.code() == DataType::kE4M3Float) ?
support::kMaxE4M3 : support::kMaxE5M2;
+ double bound =
+ (dtype.code() == DataType::kFloat8_e4m3fn) ? support::kMaxE4M3FN :
support::kMaxE5M2;
ICHECK_GE(value, -bound) << "ValueError: Literal value " << value << "
exceeds minimum of "
<< dtype;
ICHECK_LE(value, bound) << "ValueError: Literal value " << value << "
exceeds maximum of "
<< dtype;
} else if (dtype.is_float4()) {
- ICHECK_GE(value, -support::kMaxE2M1)
+ ICHECK_GE(value, -support::kMaxE2M1FN)
<< "ValueError: Literal value " << value << " exceeds minimum of "
<< dtype;
- ICHECK_LE(value, support::kMaxE2M1)
+ ICHECK_LE(value, support::kMaxE2M1FN)
<< "ValueError: Literal value " << value << " exceeds maximum of "
<< dtype;
}
}
diff --git a/src/relax/backend/contrib/utils.h
b/src/relax/backend/contrib/utils.h
index aa3928ce02..e63e99548c 100644
--- a/src/relax/backend/contrib/utils.h
+++ b/src/relax/backend/contrib/utils.h
@@ -72,10 +72,12 @@ inline std::string DType2String(const tvm::DataType dtype) {
std::ostringstream os;
if (dtype.is_float()) {
os << "float";
- } else if (dtype.is_e4m3_float8()) {
- os << "e4m3_float";
- } else if (dtype.is_e5m2_float8()) {
- os << "e5m2_float";
+ } else if (dtype.is_float8_e4m3fn()) {
+ return "float8_e4m3fn";
+ } else if (dtype.is_float8_e5m2()) {
+ return "float8_e5m2";
+ } else if (dtype.is_float4_e2m1fn()) {
+ return "float4_e2m1fn";
} else if (dtype.is_int()) {
os << "int";
} else if (dtype.is_uint()) {
diff --git a/src/runtime/contrib/cublas/cublas.cc
b/src/runtime/contrib/cublas/cublas.cc
index c9a01fc24e..ba01f791d9 100644
--- a/src/runtime/contrib/cublas/cublas.cc
+++ b/src/runtime/contrib/cublas/cublas.cc
@@ -164,8 +164,8 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
ab_type = CUDA_R_16F;
} else if (TypeMatch(A->dtype, kDLInt, 8)) {
ab_type = CUDA_R_8I;
- } else if (TypeMatch(A->dtype, DataType::TypeCode::kE4M3Float, 8)) {
- ICHECK(TypeMatch(B->dtype, DataType::TypeCode::kE4M3Float, 8));
+ } else if (TypeMatch(A->dtype, DataType::TypeCode::kFloat8_e4m3fn, 8)) {
+ ICHECK(TypeMatch(B->dtype, DataType::TypeCode::kFloat8_e4m3fn, 8));
ab_type = CUDA_R_8F_E4M3;
}
@@ -217,7 +217,6 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
int N = RowCount(A, transa, batch_offset_A);
int K = ColumnCount(A, transa, batch_offset_A);
bool use_batched_gemm = A->ndim > 2 || B->ndim > 2;
-
// If A is batched but B is not, flatten all non-reduction axes of A to use
the regular GEMM.
// This trick is only applicable if batch axes and the other spatial axis (M
or N) are
// adjacent in both the input and the output matrix. In particular, if A is
of shape (M, K)
diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc
index d876065325..5a328413a1 100644
--- a/src/runtime/ndarray.cc
+++ b/src/runtime/ndarray.cc
@@ -54,7 +54,7 @@ inline void VerifyDataType(DLDataType dtype) {
return;
else if (dtype.bits == 4 && dtype.code == kDLInt)
return;
- else if (dtype.bits == 4 && dtype.code == DataType::kFloat4E2M1Fn)
+ else if (dtype.bits == 4 && dtype.code == DataType::kFloat4_e2m1fn)
return;
else
ICHECK_EQ(dtype.bits % 8, 0);
diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc
index a73c9cb5b4..a75a357810 100644
--- a/src/script/ir_builder/tir/ir.cc
+++ b/src/script/ir_builder/tir/ir.cc
@@ -752,13 +752,13 @@
TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Float", Float);
TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.UInt", UInt);
TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Int", Int);
-TVM_REGISTER_GLOBAL("script.ir_builder.tir.E4M3Float8").set_body_typed(E4M3Float8);
-TVM_REGISTER_GLOBAL("script.ir_builder.tir.E5M2Float8").set_body_typed(E5M2Float8);
-TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.E4M3Float8", E4M3Float8);
-TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.E5M2Float8", E5M2Float8);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float8E4M3FN").set_body_typed(Float8E4M3FN);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float8E5M2").set_body_typed(Float8E5M2);
+TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E4M3FN", Float8E4M3FN);
+TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E5M2", Float8E5M2);
-TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float4E2M1fn").set_body_typed(Float4E2M1fn);
-TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float4E2M1fn", Float4E2M1fn);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float4E2M1FN").set_body_typed(Float4E2M1FN);
+TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float4E2M1FN", Float4E2M1FN);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Boolean").set_body_typed(Boolean);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Handle").set_body_typed(Handle);
diff --git a/src/support/scalars.h b/src/support/scalars.h
index b229a6b338..adc449ffd6 100644
--- a/src/support/scalars.h
+++ b/src/support/scalars.h
@@ -63,14 +63,14 @@ constexpr double kMaxBFloat16 =
3.895313892515354759047080037148786688e38;
// 2^8 * (1 + 6/8)
// See https://arxiv.org/pdf/2209.05433.pdf
-constexpr double kMaxE4M3 = 448;
+constexpr double kMaxE4M3FN = 448;
// 2^15 * (1 + 3/4)
// See https://arxiv.org/pdf/2209.05433.pdf
constexpr double kMaxE5M2 = 57344;
// 2^2 * (1 + 1/2)
-constexpr double kMaxE2M1 = 6.0;
+constexpr double kMaxE2M1FN = 6.0;
} // namespace support
} // namespace tvm
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index 9c2ce0bbb2..ead0bdff3c 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -579,9 +579,9 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType&
dtype) const {
default:
LOG(FATAL) << "do not support " << dtype;
}
- } else if (dtype.code() == DataType::kE4M3Float || dtype.code() ==
DataType::kE5M2Float) {
+ } else if (dtype.code() == DataType::kFloat8_e4m3fn || dtype.code() ==
DataType::kFloat8_e5m2) {
etype = llvm::Type::getInt8Ty(*ctx);
- } else if (dtype.code() == DataType::kFloat4E2M1Fn) {
+ } else if (dtype.code() == DataType::kFloat4_e2m1fn) {
etype = llvm::Type::getIntNTy(*ctx, 4);
}
if (!dtype.is_scalar()) {
diff --git a/src/target/source/codegen_cuda.cc
b/src/target/source/codegen_cuda.cc
index 20b29750dc..35973776c8 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -60,9 +60,9 @@ std::string GetFP8Type(DataType type) {
}
stream << "__nv_fp8";
std::string suffix;
- if (type.code() == DataType::kE4M3Float) {
+ if (type.code() == DataType::kFloat8_e4m3fn) {
suffix = "_e4m3";
- } else if (type.code() == DataType::kE5M2Float) {
+ } else if (type.code() == DataType::kFloat8_e5m2) {
suffix = "_e5m2";
} else {
LOG(FATAL) << "Unsupported FP8 type in CUDA codegen";
@@ -86,7 +86,7 @@ std::string GetFP4Type(DataType type) {
}
stream << "__nv_fp4";
std::string suffix;
- if (type.code() == DataType::kFloat4E2M1Fn) {
+ if (type.code() == DataType::kFloat4_e2m1fn) {
suffix = "_e2m1";
} else {
LOG(FATAL) << "Unsupported FP4 type in CUDA codegen";
@@ -159,9 +159,7 @@ std::string CodeGenCUDA::Finish() {
decl_stream << "#endif\n\n";
decl_stream << "#include <cuda.h>\n";
- decl_stream << "#if (CUDA_VERSION <12080)\n";
decl_stream << _cuda_half_util;
- decl_stream << "#endif\n";
}
if (enable_bf16_) {
@@ -734,9 +732,9 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op,
std::ostream& os) {
// Emit simple C-style type conversion.
if (from_ty.is_scalar()) return CodeGenC::VisitExpr_(op, os);
- if (target_ty.code() == DataType::kE4M3Float || target_ty.code() ==
DataType::kE5M2Float ||
- target_ty.code() == DataType::kFloat4E2M1Fn || from_ty.code() ==
DataType::kE4M3Float ||
- from_ty.code() == DataType::kE5M2Float || from_ty.code() ==
DataType::kFloat4E2M1Fn) {
+ if (target_ty.code() == DataType::kFloat8_e4m3fn || target_ty.code() ==
DataType::kFloat8_e5m2 ||
+ target_ty.code() == DataType::kFloat4_e2m1fn || from_ty.code() ==
DataType::kFloat8_e4m3fn ||
+ from_ty.code() == DataType::kFloat8_e5m2 || from_ty.code() ==
DataType::kFloat4_e2m1fn) {
std::ostringstream val;
val << "(";
PrintType(target_ty, val);
@@ -1508,7 +1506,7 @@ inline void PrintConst(const FloatImmNode* op,
std::ostream& os, CodeGenCUDA* p)
os << '(' << std::scientific << op->value << 'f' << ')';
return;
}
- // Type code is kE5M2Float or kE4M4Float
+ // Type code is kFloat8_e5m2 or kFloat8_e4m3fn
if (op->dtype.is_float8() || op->dtype.is_float4()) {
p->PrintType(op->dtype, os);
os << '(' << std::scientific << op->value << 'f' << ')';
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index 3dab634f16..63c82d1d6c 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -273,9 +273,9 @@ PrimExpr max_value(const DataType& dtype, Span span) {
return FloatImm(dtype, std::numeric_limits<float>::max(), span);
} else if (dtype.is_float8()) {
// according to https://arxiv.org/pdf/2209.05433.pdf
- if (dtype.code() == DataType::TypeCode::kE5M2Float) {
+ if (dtype.code() == DataType::TypeCode::kFloat8_e5m2) {
return FloatImm(dtype, 57344.0, span);
- } else if (dtype.code() == DataType::TypeCode::kE4M3Float) {
+ } else if (dtype.code() == DataType::TypeCode::kFloat8_e4m3fn) {
return FloatImm(dtype, 448.0, span);
}
} else if (dtype.is_float4()) {
@@ -316,9 +316,9 @@ PrimExpr min_value(const DataType& dtype, Span span) {
return FloatImm(dtype, std::numeric_limits<float>::lowest(), span);
} else if (dtype.is_float8()) {
// according to https://arxiv.org/pdf/2209.05433.pdf
- if (dtype.code() == DataType::TypeCode::kE5M2Float) {
+ if (dtype.code() == DataType::TypeCode::kFloat8_e5m2) {
return FloatImm(dtype, -57344.0, span);
- } else if (dtype.code() == DataType::TypeCode::kE4M3Float) {
+ } else if (dtype.code() == DataType::TypeCode::kFloat8_e4m3fn) {
return FloatImm(dtype, -448.0, span);
}
} else if (dtype.is_float4()) {
diff --git a/src/tir/transforms/dtype_conversion.h
b/src/tir/transforms/dtype_conversion.h
index 8edbf1bc1e..a0ed6b5f6d 100644
--- a/src/tir/transforms/dtype_conversion.h
+++ b/src/tir/transforms/dtype_conversion.h
@@ -121,7 +121,7 @@ class FloatConfig {
// NVIDIA/Arm/Intel's FP8 formats for Deep Learning
// Reference: https://arxiv.org/abs/2209.05433
switch (dtype.code()) {
- case DataType::kE4M3Float:
+ case DataType::kFloat8_e4m3fn:
// E4M3 format, not consistent with IEEE-754
return FloatConfig(4, 3, 7, InftyStyle::kNone, NaNStyle::kAllOnes);
default:
diff --git a/tests/python/codegen/test_target_codegen_cuda_fp8.py
b/tests/python/codegen/test_target_codegen_cuda_fp8.py
index d94153003c..b91efd6192 100644
--- a/tests/python/codegen/test_target_codegen_cuda_fp8.py
+++ b/tests/python/codegen/test_target_codegen_cuda_fp8.py
@@ -16,25 +16,24 @@
# under the License.
import sys
+from typing import List, Tuple
+
+import numpy as np
import pytest
import tvm
-from tvm.script import tir as T
-import numpy as np
import tvm.testing
-
-
-from typing import List, Tuple
from tvm import DataType, DataTypeCode, IRModule
from tvm import dlight as dl
from tvm import relax, te, tir, topi
from tvm.relax.frontend import nn
from tvm.runtime import NDArray
+from tvm.script import ir as I
+from tvm.script import relax as R
+from tvm.script import tir as T
from tvm.target import Target
from tvm.topi.utils import get_const_tuple
-from tvm.script import ir as I, relax as R, tir as T
-
try:
import ml_dtypes
except ImportError:
@@ -43,7 +42,7 @@ except ImportError:
@tvm.testing.requires_cuda_compute_version(8, 9)
def test_e4m3_conversions():
- dtype = "e4m3_float8"
+ dtype = "float8_e4m3fn"
@T.prim_func
def add(
@@ -90,7 +89,7 @@ def test_e4m3_conversions():
def test_e4m3_packing():
length = 64
vector_length = 4
- native_dtype, packed_dtype = ("e4m3_float8x4", "uint32")
+ native_dtype, packed_dtype = ("float8_e4m3fnx4", "uint32")
@T.prim_func
def add(
@@ -141,13 +140,13 @@ def test_e4m3_packing():
native_dtype, promoted_dtype = tvm.testing.parameters(
- ("e4m3_float8", "float32"),
- ("e4m3_float8", "float16"),
- ("e4m3_float8x2", "float32x2"),
- ("e4m3_float8x2", "float16x2"),
- ("e4m3_float8x4", "float32x4"),
+ ("float8_e4m3fn", "float32"),
+ ("float8_e4m3fn", "float16"),
+ ("float8_e4m3fnx2", "float32x2"),
+ ("float8_e4m3fnx2", "float16x2"),
+ ("float8_e4m3fnx4", "float32x4"),
# Supported via half4 vector type extension in codegen
- ("e4m3_float8x4", "float16x4"),
+ ("float8_e4m3fnx4", "float16x4"),
)
@@ -343,7 +342,7 @@ class BaseFP8E4M3QuantScaleOnly:
axis,
output_transpose,
) -> IRModule:
- if DataType(quantize_dtype).type_code == DataTypeCode.E4M3Float:
+ if DataType(quantize_dtype).type_code == DataTypeCode.Float8E4M3FN:
quantize_func = cls.quantize_fp8x4_e4m3
else:
assert NotImplementedError()
@@ -387,7 +386,7 @@ class BaseFP8E4M3QuantScaleOnly:
num_elem_per_storage,
axis,
) -> IRModule:
- if DataType(quantize_dtype).type_code == DataTypeCode.E4M3Float:
+ if DataType(quantize_dtype).type_code == DataTypeCode.Float8E4M3FN:
dequantize_func = cls.dequantize_fp8x4_e4m3
else:
assert NotImplementedError()
@@ -732,7 +731,7 @@ class
TestFP8e4x4QuantDequantScale(BaseFP8E4M3QuantScaleOnly):
@tvm.testing.fixture
def quantize_dtype(self):
- return "e4m3_float8"
+ return "float8_e4m3fn"
@tvm.testing.fixture
def num_el_per_storage(self):
@@ -807,7 +806,7 @@ class
TestFP8e4x4QuantDequantScale(BaseFP8E4M3QuantScaleOnly):
@tvm.testing.requires_cuda_compute_version(8, 9)
[email protected]("dtype", ["e5m2_float8", "e4m3_float8"])
[email protected]("dtype", ["float8_e5m2", "float8_e4m3fn"])
def test_const(dtype):
@T.prim_func
def func(A: T.Buffer((4,), dtype)) -> None:
@@ -822,7 +821,7 @@ def test_const(dtype):
@tvm.testing.requires_cuda_compute_version(8, 9)
[email protected]("dtype", ["e5m2_float8", "e4m3_float8"])
[email protected]("dtype", ["float8_e5m2", "float8_e4m3fn"])
@pytest.mark.parametrize("vec_len", [2, 4, 8, 16])
def test_copy(dtype, vec_len):
@T.prim_func
@@ -867,7 +866,7 @@ def test_moe_gemv_shfl_down_illegal_instr():
@T.prim_func(private=True)
def moe_dequantize_gemv(
x_handle: T.handle,
- w: T.Buffer((num_experts, spatial_size, reduce_size),
"e4m3_float8"),
+ w: T.Buffer((num_experts, spatial_size, reduce_size),
"float8_e4m3fn"),
scale: T.Buffer((1,), "float16"),
indptr: T.Buffer((1, 2), "int32"),
o: T.Buffer((2, spatial_size), "float16"),
@@ -905,7 +904,7 @@ def test_moe_gemv_shfl_down_illegal_instr():
def main(
x: R.Tensor(("num_seq", reduce_size), dtype="float16"),
indptr: R.Tensor((1, 2), dtype="int32"),
- weight: R.Tensor((num_experts, spatial_size, reduce_size),
dtype="e4m3_float8"),
+ weight: R.Tensor((num_experts, spatial_size, reduce_size),
dtype="float8_e4m3fn"),
scale: R.Tensor((1,), dtype="float32"),
) -> R.Tensor((2, spatial_size), dtype="float16"):
num_seq = T.int64()
@@ -965,4 +964,5 @@ def test_moe_gemv_shfl_down_illegal_instr():
if __name__ == "__main__":
+ # test_half_broadcast(6)
tvm.testing.main()
diff --git a/tests/python/ir/test_datatype_nv_fp8.py
b/tests/python/ir/test_datatype_nv_fp8.py
index 8313a97ee1..b812c70ab6 100644
--- a/tests/python/ir/test_datatype_nv_fp8.py
+++ b/tests/python/ir/test_datatype_nv_fp8.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import numpy as np
+
import tvm
import tvm.testing
import tvm.tir as tir
@@ -22,9 +23,10 @@ from tvm import te
from tvm.script import tir as T
try:
- from ml_dtypes import float8_e4m3fn as e4m3_float8, float8_e5m2 as
e5m2_float8
+ from ml_dtypes import float8_e4m3fn as float8_e4m3fn
+ from ml_dtypes import float8_e5m2 as float8_e5m2
except ImportError:
- e4m3_float8, e5m2_float8 = None, None
+ float8_e4m3fn, float8_e5m2 = None, None
def fp8_unary(dtype: str):
@@ -58,7 +60,7 @@ def fp8_unary(dtype: str):
np_dtype, dtype_str = tvm.testing.parameters(
- (e4m3_float8, "e4m3_float8"), (e5m2_float8, "e5m2_float8")
+ (float8_e4m3fn, "float8_e4m3fn"), (float8_e5m2, "float8_e5m2")
)
diff --git a/tests/python/ir/test_dtype.py b/tests/python/ir/test_dtype.py
index 77cd1d7e4b..988e360748 100644
--- a/tests/python/ir/test_dtype.py
+++ b/tests/python/ir/test_dtype.py
@@ -15,22 +15,23 @@
# specific language governing permissions and limitations
# under the License.
"""Test data type related API"""
+import pytest
+
import tvm
-from tvm import DataType
import tvm.testing
-import pytest
+from tvm import DataType
@pytest.mark.parametrize(
"dtype_str, expected_size",
- [("float32", 4), ("float32x4", 16), ("e5m2_float8x4", 4), ("uint8", 1)],
+ [("float32", 4), ("float32x4", 16), ("float8_e5m2x4", 4), ("uint8", 1)],
)
def test_dtype_itemsize(dtype_str, expected_size):
dtype = DataType(dtype_str)
assert dtype.itemsize() == expected_size
[email protected]("dtype_str", [("int32xvscalex4")])
[email protected]("dtype_str", ["int32xvscalex4"])
def test_dtype_itemmize_error(dtype_str):
with pytest.raises(ValueError):
size = DataType(dtype_str).itemsize()
diff --git a/tests/python/relax/test_codegen_cublas.py
b/tests/python/relax/test_codegen_cublas.py
index 2fbff8433b..c5514e2727 100644
--- a/tests/python/relax/test_codegen_cublas.py
+++ b/tests/python/relax/test_codegen_cublas.py
@@ -315,7 +315,7 @@ def test_matmul_fp8_offload(
transpose_y,
out_dtype,
):
- in_dtype = "e4m3_float8"
+ in_dtype = "float8_e4m3fn"
mod = get_relax_matmul_module(
x_shape,
y_shape,
@@ -342,7 +342,7 @@ def test_matmul_fp8_offload(
def test_matmul_fp8_dequantize_offload():
x_shape = (10, 32)
y_shape = (64, 32)
- in_dtype = "e4m3_float8"
+ in_dtype = "float8_e4m3fn"
mod = get_relax_matmul_dequantize_module(
x_shape,
y_shape,
@@ -369,7 +369,7 @@ def test_matmul_fp8_multiply_offload():
x_shape = (10, 32)
y_shape = (64, 32)
z_shape = (1,)
- in_dtype, acc_dtype = ("e4m3_float8", "float32")
+ in_dtype, acc_dtype = ("float8_e4m3fn", "float32")
mod = get_relax_matmul_multiply_module(
x_shape,
@@ -397,8 +397,8 @@ def test_matmul_fp8_multiply_offload():
"M, N, K, out_dtype, transposed_y, partition_done",
[
(15, 64, 32, "float32", True, True),
- (15, 64, 32, "e4m3_float8", True, True),
- (15, 64, 32, "e5m2_float8", True, False),
+ (15, 64, 32, "float8_e4m3fn", True, True),
+ (15, 64, 32, "float8_e5m2", True, False),
(16, 32, 60, "float32", True, False),
(16, 30, 64, "float32", True, False),
(16, 8, 16, "float16", True, True),
@@ -407,7 +407,7 @@ def test_matmul_fp8_multiply_offload():
)
def test_cublas_partition_fp8_matmul(M, N, K, out_dtype, transposed_y,
partition_done):
mod = get_relax_matmul_module(
- (M, K), (N, K), "e4m3_float8", out_dtype, transposed_y=transposed_y
+ (M, K), (N, K), "float8_e4m3fn", out_dtype, transposed_y=transposed_y
)
mod = partition_for_cublas(mod)
func_name = "relax_matmul_cublas" if partition_done else "R.matmul"
@@ -426,7 +426,7 @@ def test_cublas_partition_fp8_matmul_dequantize(M, N, K,
scale, zp, num_bindings
mod = get_relax_matmul_dequantize_module(
(M, K),
(N, K),
- "e4m3_float8",
+ "float8_e4m3fn",
"float16",
transposed_y=True,
scale_const=scale,
@@ -443,7 +443,7 @@ def test_cublas_partition_fp8_matmul_multiply():
(M, K),
(N, K),
(1,),
- "e4m3_float8",
+ "float8_e4m3fn",
"float32",
"float16",
transposed_y=True,
diff --git a/tests/python/relax/test_op_inspect.py
b/tests/python/relax/test_op_inspect.py
index 18d7a88f05..ca4b0fc440 100644
--- a/tests/python/relax/test_op_inspect.py
+++ b/tests/python/relax/test_op_inspect.py
@@ -21,10 +21,10 @@ import numpy as np
import pytest
import tvm.testing
-
from tvm import relax
from tvm.ir import Op
-from tvm.script import ir as I, relax as R
+from tvm.script import ir as I
+from tvm.script import relax as R
# Parameterization for reading dtype of DLTensor. Chosen to have
# multiple distinct type codes, number of lanes, and widths.
@@ -34,7 +34,7 @@ dtype = tvm.testing.parameter(
"float32",
"float32x4",
"bfloat",
- "e4m3_float8",
+ "float8_e4m3fn",
)
shape = tvm.testing.parameter(
[],
diff --git a/tests/python/relax/test_op_qdq.py
b/tests/python/relax/test_op_qdq.py
index 8b2d499041..d773a6c7d2 100644
--- a/tests/python/relax/test_op_qdq.py
+++ b/tests/python/relax/test_op_qdq.py
@@ -68,17 +68,17 @@ def test_qdq_op_infer_struct_info_symbolic():
)
-def test_qdq_e4m3_float8_op_infer_struct_info_symbolic():
+def test_qdq_float8_e4m3fn_op_infer_struct_info_symbolic():
bb = relax.BlockBuilder()
n = tir.Var("n", "int64")
x = relax.Var("x", R.Tensor((n, 3), "float32"))
- dx = relax.Var("dx", R.Tensor((n, 3), "e4m3_float8"))
+ dx = relax.Var("dx", R.Tensor((n, 3), "float8_e4m3fn"))
s = relax.Var("s", R.Tensor([3], "float32"))
zp = relax.Var("zp", R.Tensor([3], "float16"))
_check_inference(
bb,
- relax.op.quantize(x, s, zp, 1, "e4m3_float8"),
- relax.TensorStructInfo((n, 3), "e4m3_float8"),
+ relax.op.quantize(x, s, zp, 1, "float8_e4m3fn"),
+ relax.TensorStructInfo((n, 3), "float8_e4m3fn"),
)
_check_inference(
bb,
@@ -87,8 +87,8 @@ def test_qdq_e4m3_float8_op_infer_struct_info_symbolic():
)
-def test_qdq_e5m2_float8_op_infer_struct_info_symbolic():
- dtype = "e5m2_float8"
+def test_qdq_float8_e5m2_op_infer_struct_info_symbolic():
+ dtype = "float8_e5m2"
bb = relax.BlockBuilder()
n = tir.Var("n", "int64")
x = relax.Var("x", R.Tensor((n, 3), "float32"))
diff --git
a/tests/python/tir-schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py
b/tests/python/tir-schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py
index fe9998bc79..5a80e3e4f6 100644
---
a/tests/python/tir-schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py
+++
b/tests/python/tir-schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py
@@ -17,21 +17,24 @@
# pylint: disable=missing-docstring
import numpy as np
import pytest
+
import tvm
import tvm.testing
from tvm import te
from tvm.testing.tir import mma_schedule
from tvm.tir.tensor_intrin.cuda import (
+ LDMATRIX_e4m3_A_INTRIN,
+ LDMATRIX_e4m3_B_TRANS_INTRIN,
+ LDMATRIX_e5m2_A_INTRIN,
+ LDMATRIX_e5m2_B_TRANS_INTRIN,
LDMATRIX_f16_A_INTRIN,
LDMATRIX_f16_B_INTRIN,
LDMATRIX_f16_B_TRANS_INTRIN,
LDMATRIX_i8_A_INTRIN,
- LDMATRIX_i8_B_TRANS_INTRIN,
LDMATRIX_i8_B_INTRIN,
- LDMATRIX_e4m3_A_INTRIN,
- LDMATRIX_e4m3_B_TRANS_INTRIN,
- LDMATRIX_e5m2_A_INTRIN,
- LDMATRIX_e5m2_B_TRANS_INTRIN,
+ LDMATRIX_i8_B_TRANS_INTRIN,
+ MMA_e4m3e4m3f32_TRANS_B_INTRIN,
+ MMA_e5m2e5m2f32_TRANS_B_INTRIN,
MMA_f16f16f16_INTRIN,
MMA_f16f16f16_TRANS_B_INTRIN,
MMA_f16f16f32_INTRIN,
@@ -41,8 +44,6 @@ from tvm.tir.tensor_intrin.cuda import (
MMA_fill_16x16_i32_INTRIN,
MMA_i8i8i32_INTRIN,
MMA_i8i8i32_TRANS_B_INTRIN,
- MMA_e5m2e5m2f32_TRANS_B_INTRIN,
- MMA_e4m3e4m3f32_TRANS_B_INTRIN,
MMA_store_16x16_f16_global_INTRIN,
MMA_store_16x16_f32_global_INTRIN,
MMA_store_16x16_i32_global_INTRIN,
@@ -132,10 +133,10 @@ def run_test(
else:
b_np = np.random.normal(size=(K, N)).astype("float16")
c_np = np.dot(a_np.astype("float32"),
b_np.astype("float32")).astype(out_dtype)
- elif in_dtype in ["e4m3_float8", "e5m2_float8"]:
+ elif in_dtype in ["float8_e4m3fn", "float8_e5m2"]:
typemap = {
- "e4m3_float8": "float8_e4m3fn",
- "e5m2_float8": "float8_e5m2",
+ "float8_e4m3fn": "float8_e4m3fn",
+ "float8_e5m2": "float8_e5m2",
}
a_np = (
np.random.uniform(low=-5, high=5, size=(M * K))
@@ -174,7 +175,7 @@ def run_test(
f(a, b, c)
- if out_dtype != "float16" and in_dtype not in ["e4m3_float8",
"e5m2_float8"]:
+ if out_dtype != "float16" and in_dtype not in ["float8_e4m3fn",
"float8_e5m2"]:
# The numpy reference is computed with fp32 precision (otherwise too
slow).
# So there is non-trivial accuracy difference if TVM result is
computed with fp16 accumulation.
tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-2, atol=1e-2)
@@ -384,7 +385,7 @@ def test_e4m3e4m3f32_m16n16k32():
)
k_inner = 32
- in_dtype = "e4m3_float8"
+ in_dtype = "float8_e4m3fn"
out_dtype = "float32"
i_factors, j_factors, k_factors = [1, 32, 1, 4, 2], [8, 4, 4, 2, 1], [32,
2, 2]
@@ -427,7 +428,7 @@ def test_e5m2e5m2f32_m16n16k32():
)
k_inner = 32
- in_dtype = "e5m2_float8"
+ in_dtype = "float8_e5m2"
out_dtype = "float32"
i_factors, j_factors, k_factors = [1, 32, 1, 4, 2], [8, 4, 4, 2, 1], [32,
2, 2]
diff --git a/tests/python/tir-transform/test_tir_transform_fp8_legalize.py
b/tests/python/tir-transform/test_tir_transform_fp8_legalize.py
index e1f487c572..0b10fe5c21 100644
--- a/tests/python/tir-transform/test_tir_transform_fp8_legalize.py
+++ b/tests/python/tir-transform/test_tir_transform_fp8_legalize.py
@@ -17,8 +17,8 @@
import tvm
import tvm.script
import tvm.testing
-from tvm.target import Target
from tvm.script import tir as T
+from tvm.target import Target
from tvm.tir.transform.transform import BindTarget
# pylint: disable=no-member,invalid-name,unused-variable
@@ -69,7 +69,7 @@ def get_after_compute_legalize(dtype: str, promote_dtype:
str):
def promote_uint8(f8_dtype: str, promote_dtype: str, v):
- if f8_dtype == "e4m3_float8":
+ if f8_dtype == "float8_e4m3fn":
if promote_dtype == "float16":
mantissa = T.bitwise_and(
T.shift_left(T.Cast("uint16", v), T.uint16(7)), T.uint16(0x3FF)
@@ -96,7 +96,7 @@ def promote_uint8(f8_dtype: str, promote_dtype: str, v):
)
sign = T.shift_left(T.Cast("uint32", T.shift_right(v,
T.uint8(7))), T.uint32(31))
return T.reinterpret("float32",
T.bitwise_or(T.bitwise_or(mantissa, exponent), sign))
- else: # f8_dtype == "e5m2_float8"
+ else: # f8_dtype == "float8_e5m2"
if promote_dtype == "float16":
return T.reinterpret("float16", T.shift_left(T.Cast("uint16", v),
T.uint16(8)))
else: # promote_dtype == "float32"
@@ -115,7 +115,7 @@ def promote_uint8(f8_dtype: str, promote_dtype: str, v):
def cast_to_uint8(f8_dtype: str, promote_dtype: str, v):
- if f8_dtype == "e4m3_float8":
+ if f8_dtype == "float8_e4m3fn":
if promote_dtype == "float16":
uint16_v = T.reinterpret("uint16", v)
rounding_bias = T.bitwise_and(
@@ -154,7 +154,7 @@ def cast_to_uint8(f8_dtype: str, promote_dtype: str, v):
return T.if_then_else(
round_to_zero, T.uint8(0), T.bitwise_or(T.bitwise_or(mantissa,
exponent), sign)
)
- else: # f8_dtype == "e5m2_float8"
+ else: # f8_dtype == "float8_e5m2"
if promote_dtype == "float16":
uint16_v = T.reinterpret("uint16", v)
rounding_bias = T.bitwise_and(
@@ -201,12 +201,12 @@ def get_after_storage_legalize(dtype: str, promote_dtype:
str):
return After
-dtype = tvm.testing.parameter("e4m3_float8", "e5m2_float8")
+dtype = tvm.testing.parameter("float8_e4m3fn", "float8_e5m2")
promote_dtype = tvm.testing.parameter("float16", "float32")
def test_fp8_compute_legalize(dtype, promote_dtype):
- target = Target("cuda")
+ target = Target("nvidia/nvidia-a100")
before = BindTarget(target)(get_before(dtype))
expected = BindTarget(target)(get_after_compute_legalize(dtype,
promote_dtype))
# run the transform twice to ensure we can afford to deal
@@ -217,7 +217,7 @@ def test_fp8_compute_legalize(dtype, promote_dtype):
def test_fp8_storage_legalize(dtype, promote_dtype):
- target = Target("cuda")
+ target = Target("nvidia/nvidia-a100")
before = BindTarget(target)(get_after_compute_legalize(dtype,
promote_dtype))
after = tvm.tir.transform.FP8StorageLegalize()(before)
expected = BindTarget(target)(get_after_storage_legalize(dtype,
promote_dtype))
diff --git a/tests/python/tvmscript/test_tvmscript_printer_tir.py
b/tests/python/tvmscript/test_tvmscript_printer_tir.py
index b7ba57fa93..943ba54060 100644
--- a/tests/python/tvmscript/test_tvmscript_printer_tir.py
+++ b/tests/python/tvmscript/test_tvmscript_printer_tir.py
@@ -17,6 +17,7 @@
# pylint: disable=missing-docstring
import re
+
import pytest
import tvm.testing
@@ -917,23 +918,23 @@ def func():
_assert_print(func, expected_output)
[email protected]("dtype", ["e4m3_float8", "e5m2_float8"])
[email protected]("dtype", ["float8_e4m3fn", "float8_e5m2"])
def test_float8(dtype):
from tvm.script import tir as T
def get_func(dtype):
- if dtype == "e4m3_float8":
+ if dtype == "float8_e4m3fn":
@T.prim_func
def func():
- T.evaluate(T.e4m3_float8(0.0))
+ T.evaluate(T.float8_e4m3fn(0.0))
return func
- elif dtype == "e5m2_float8":
+ elif dtype == "float8_e5m2":
@T.prim_func
def func():
- T.evaluate(T.e5m2_float8(0.0))
+ T.evaluate(T.float8_e5m2(0.0))
return func