This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 2829b59e1c [TVMScript] Add parser and printer support for e4m3/e5m2 fp8 (#16864)
2829b59e1c is described below
commit 2829b59e1c78796da273b650f006628bca64cfcc
Author: Wuwei Lin <[email protected]>
AuthorDate: Wed Apr 10 05:22:41 2024 -0700
[TVMScript] Add parser and printer support for e4m3/e5m2 fp8 (#16864)
* [TVMScript] Add parser and printer support for e4m3/e5m2 fp8
* remove unrelated
---
include/tvm/script/ir_builder/tir/ir.h | 12 +++++++
python/tvm/script/ir_builder/tir/ir.py | 39 +++++++++++++++-------
src/script/ir_builder/tir/ir.cc | 5 +++
.../python/tvmscript/test_tvmscript_printer_tir.py | 31 +++++++++++++++++
4 files changed, 75 insertions(+), 12 deletions(-)
diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h
index 735d5ba6c0..c4ba44f673 100644
--- a/include/tvm/script/ir_builder/tir/ir.h
+++ b/include/tvm/script/ir_builder/tir/ir.h
@@ -489,6 +489,18 @@ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(Int, DataType::Int);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Float, DataType::Float);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(UInt, DataType::UInt);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Int, DataType::Int);
+
+#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(DType, FDType) \
+ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType, FDType(1)); \
+ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x4, FDType(4)); \
+ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x8, FDType(8)); \
+ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x16, FDType(16)); \
+ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x32, FDType(32)); \
+ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x64, FDType(64));
+
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(E4M3Float8, DataType::NVFloat8E4M3);
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(E5M2Float8, DataType::NVFloat8E5M2);
+
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Boolean, DataType::Bool());
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void());
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index a5c09cf1a3..127d2a4356 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -1408,30 +1408,39 @@ uint16x64 = func_gen(("UInt16x64"))
uint32x64 = func_gen(("UInt32x64"))
uint64x64 = func_gen(("UInt64x64"))
-float8 = func_gen(("Float8"))
float16 = func_gen(("Float16"))
float32 = func_gen(("Float32"))
float64 = func_gen(("Float64"))
-float8x4 = func_gen(("Float8x4"))
float16x4 = func_gen(("Float16x4"))
float32x4 = func_gen(("Float32x4"))
float64x4 = func_gen(("Float64x4"))
-float8x8 = func_gen(("Float8x8"))
float16x8 = func_gen(("Float16x8"))
float32x8 = func_gen(("Float32x8"))
float64x8 = func_gen(("Float64x8"))
-float8x16 = func_gen(("Float8x16"))
float16x16 = func_gen(("Float16x16"))
float32x16 = func_gen(("Float32x16"))
float64x16 = func_gen(("Float64x16"))
-float8x32 = func_gen(("Float8x32"))
float16x32 = func_gen(("Float16x32"))
float32x32 = func_gen(("Float32x32"))
float64x32 = func_gen(("Float64x32"))
-float8x64 = func_gen(("Float8x64"))
float16x64 = func_gen(("Float16x64"))
float32x64 = func_gen(("Float32x64"))
float64x64 = func_gen(("Float64x64"))
+
+e4m3_float8 = func_gen(("E4M3Float8"))
+e4m3_float8x4 = func_gen(("E4M3Float8x4"))
+e4m3_float8x8 = func_gen(("E4M3Float8x8"))
+e4m3_float8x16 = func_gen(("E4M3Float8x16"))
+e4m3_float8x32 = func_gen(("E4M3Float8x32"))
+e4m3_float8x64 = func_gen(("E4M3Float8x64"))
+
+e5m2_float8 = func_gen(("E5M2Float8"))
+e5m2_float8x4 = func_gen(("E5M2Float8x4"))
+e5m2_float8x8 = func_gen(("E5M2Float8x8"))
+e5m2_float8x16 = func_gen(("E5M2Float8x16"))
+e5m2_float8x32 = func_gen(("E5M2Float8x32"))
+e5m2_float8x64 = func_gen(("E5M2Float8x64"))
+
# pylint: enable=invalid-name
@@ -1954,27 +1963,33 @@ __all__ = [
"uint16x64",
"uint32x64",
"uint64x64",
- "float8",
+ "e4m3_float8",
+ "e5m2_float8",
"float16",
"float32",
"float64",
- "float8x4",
+ "e4m3_float8x4",
+ "e5m2_float8x4",
"float16x4",
"float32x4",
"float64x4",
- "float8x8",
+ "e4m3_float8x8",
+ "e5m2_float8x8",
"float16x8",
"float32x8",
"float64x8",
- "float8x16",
+ "e4m3_float8x16",
+ "e5m2_float8x16",
"float16x16",
"float32x16",
"float64x16",
- "float8x32",
+ "e4m3_float8x32",
+ "e5m2_float8x32",
"float16x32",
"float32x32",
"float64x32",
- "float8x64",
+ "e4m3_float8x64",
+ "e5m2_float8x64",
"float16x64",
"float32x64",
"float64x64",
diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc
index 1ae1051d25..ccb5a8b57b 100644
--- a/src/script/ir_builder/tir/ir.cc
+++ b/src/script/ir_builder/tir/ir.cc
@@ -751,6 +751,11 @@
TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Float", Float);
TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.UInt", UInt);
TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Int", Int);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.E4M3Float8").set_body_typed(E4M3Float8);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.E5M2Float8").set_body_typed(E5M2Float8);
+TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.E4M3Float8", E4M3Float8);
+TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.E5M2Float8", E5M2Float8);
+
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Boolean").set_body_typed(Boolean);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Handle").set_body_typed(Handle);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Void").set_body_typed(Void);
diff --git a/tests/python/tvmscript/test_tvmscript_printer_tir.py b/tests/python/tvmscript/test_tvmscript_printer_tir.py
index 97a6b889c0..edc6da3163 100644
--- a/tests/python/tvmscript/test_tvmscript_printer_tir.py
+++ b/tests/python/tvmscript/test_tvmscript_printer_tir.py
@@ -917,5 +917,36 @@ def func():
_assert_print(func, expected_output)
+@pytest.mark.parametrize("dtype", ["e4m3_float8", "e5m2_float8"])
+def test_float8(dtype):
+ from tvm.script import tir as T
+
+ def get_func(dtype):
+ if dtype == "e4m3_float8":
+
+ @T.prim_func
+ def func():
+ T.evaluate(T.e4m3_float8(0.0))
+
+ return func
+ elif dtype == "e5m2_float8":
+
+ @T.prim_func
+ def func():
+ T.evaluate(T.e5m2_float8(0.0))
+
+ return func
+
+ expected_output = f"""
+# from tvm.script import tir as T
+
+@T.prim_func
+def func():
+ T.evaluate(T.{dtype}(0))
+ """
+ func = get_func(dtype)
+ _assert_print(func, expected_output)
+
+
if __name__ == "__main__":
tvm.testing.main()