This is an automated email from the ASF dual-hosted git repository.
yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new cb08f0d57b [TIR][Driver] Use `BindTarget` to specify target for FP8
legalization (#16767)
cb08f0d57b is described below
commit cb08f0d57b5098a6edadad18ee058523087d81f1
Author: Steven S. Lyubomirsky <[email protected]>
AuthorDate: Sun Mar 24 20:26:35 2024 -0400
[TIR][Driver] Use `BindTarget` to specify target for FP8 legalization
(#16767)
* Do not pass target explicitly to FP8 legalization, use BindTarget instead
* Lint: Remove unused import
* Add comment on pass ordering
---
include/tvm/tir/transform.h | 8 ++++----
python/tvm/tir/transform/transform.py | 18 +++++-------------
src/driver/driver_api.cc | 8 ++++----
src/tir/transforms/unsupported_dtype_legalize.cc | 6 ++++--
.../tir-transform/test_tir_transform_fp8_legalize.py | 15 ++++++++-------
5 files changed, 25 insertions(+), 30 deletions(-)
diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index e219cc6846..98edbeaceb 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -398,7 +398,6 @@ TVM_DLL Pass ForceNarrowIndexToInt32();
/*!
* \brief Legalize bf16 compute Ops. Add a cast to fp32
* before Ops, then add a cast back to bf16.
- * \param target The target used for checking native bf16 support
* \return The pass.
*/
TVM_DLL Pass BF16ComputeLegalize();
@@ -406,11 +405,11 @@ TVM_DLL Pass BF16ComputeLegalize();
/*!
* \brief Legalize fp8 compute Ops. Add a cast to fp16/fp32
* before Ops, then add a cast back to fp8.
- * \param target The target used for checking native fp8 support
* \param promote_dtype_str The data type used for type promotion, defaults to
float16
+ * \note Must be run after BindTarget, as it relies on target attributes for
PrimFuncs
* \return The pass.
*/
-TVM_DLL Pass FP8ComputeLegalize(Target target, String promote_dtype_str =
"float16");
+TVM_DLL Pass FP8ComputeLegalize(String promote_dtype_str = "float16");
/*!
* \brief Legalize bf16 storage types to u16.
@@ -420,9 +419,10 @@ TVM_DLL Pass BF16StorageLegalize();
/*!
* \brief Legalize fp8 storage types to u8.
+ * \note Must be run after BindTarget, as it relies on target attributes for
PrimFuncs
* \return The pass.
*/
-TVM_DLL Pass FP8StorageLegalize(Target target);
+TVM_DLL Pass FP8StorageLegalize();
/*!
* \brief Inline calls to private functions
diff --git a/python/tvm/tir/transform/transform.py
b/python/tvm/tir/transform/transform.py
index 9f7f92dbed..c2022b9186 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -19,7 +19,7 @@
import enum
-from typing import Any, Callable, Optional
+from typing import Callable, Optional
from . import _ffi_api
from . import function_pass as _fpass
@@ -323,7 +323,7 @@ def BF16ComputeLegalize():
return _ffi_api.BF16ComputeLegalize() # type: ignore
-def FP8ComputeLegalize(target: Any, promote_dtype_str: str = "float32"):
+def FP8ComputeLegalize(promote_dtype_str: str = "float32"):
"""Legalize fp8 compute Ops.
Parameters
@@ -331,15 +331,12 @@ def FP8ComputeLegalize(target: Any, promote_dtype_str:
str = "float32"):
promote_dtype_str : str
The data type we promote fp8 to, options: float16/float32.
- target : tvm.target.Target
- The legalization target
-
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
- return _ffi_api.FP8ComputeLegalize(target, promote_dtype_str) # type:
ignore
+ return _ffi_api.FP8ComputeLegalize(promote_dtype_str) # type: ignore
def BF16StorageLegalize():
@@ -353,20 +350,15 @@ def BF16StorageLegalize():
return _ffi_api.BF16StorageLegalize() # type: ignore
-def FP8StorageLegalize(target: Any):
+def FP8StorageLegalize():
"""Legalize fp8 storage types to u8.
- Parameters
- ----------
- target : tvm.target.Target
- The legalization target
-
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
- return _ffi_api.FP8StorageLegalize(target) # type: ignore
+ return _ffi_api.FP8StorageLegalize() # type: ignore
def CommonSubexprElimTIR(enable_cse_tir: bool = True, identify_equiv_terms:
bool = False):
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index 33b4514e6b..7ea5032fa0 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -569,15 +569,15 @@ transform::Sequential MixedModulePassManager(IRModule
mixed_mod, Target target)
Array<Pass> mixed_pass_list;
- mixed_pass_list.push_back(tir::transform::FP8ComputeLegalize(target));
// FP8ComputeLegalize uses the target attrs added by BindTarget, so BindTarget
must come first
+ mixed_pass_list.push_back(tir::transform::BindTarget(target));
+ mixed_pass_list.push_back(tir::transform::FP8ComputeLegalize());
// VerifyVTCMLimit must occur before LowerVtcmAlloc
mixed_pass_list.push_back(tir::transform::VerifyVTCMLimit(target));
// LowerVtcmAlloc must occur after any transformations that modify memory
allocation locations
mixed_pass_list.push_back(tir::transform::LowerVtcmAlloc());
- mixed_pass_list.push_back(tir::transform::BindTarget(target));
-
mixed_pass_list.push_back(tir::transform::VerifyMemory());
mixed_pass_list.push_back(tir::transform::AnnotateEntryFunc());
@@ -620,7 +620,7 @@ transform::Sequential MixedModulePassManager(IRModule
mixed_mod, Target target)
} else {
mixed_pass_list.push_back(tir::transform::MakePackedAPI());
}
- mixed_pass_list.push_back(tir::transform::FP8StorageLegalize(target));
+ mixed_pass_list.push_back(tir::transform::FP8StorageLegalize());
mixed_pass_list.push_back(tir::transform::BF16StorageLegalize());
mixed_pass_list.push_back(tir::transform::LowerDeviceKernelLaunch());
diff --git a/src/tir/transforms/unsupported_dtype_legalize.cc
b/src/tir/transforms/unsupported_dtype_legalize.cc
index c037879074..5537c8a409 100644
--- a/src/tir/transforms/unsupported_dtype_legalize.cc
+++ b/src/tir/transforms/unsupported_dtype_legalize.cc
@@ -727,8 +727,9 @@ Pass BF16StorageLegalize() {
TVM_REGISTER_GLOBAL("tir.transform.BF16StorageLegalize").set_body_typed(BF16StorageLegalize);
-Pass FP8ComputeLegalize(Target target, String promote_dtype_str) {
+Pass FP8ComputeLegalize(String promote_dtype_str) {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+ auto target = f->GetAttr<Target>(tvm::attr::kTarget).value();
if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_fp8")) {
return f;
}
@@ -739,8 +740,9 @@ Pass FP8ComputeLegalize(Target target, String
promote_dtype_str) {
TVM_REGISTER_GLOBAL("tir.transform.FP8ComputeLegalize").set_body_typed(FP8ComputeLegalize);
-Pass FP8StorageLegalize(Target target) {
+Pass FP8StorageLegalize() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+ auto target = f->GetAttr<Target>(tvm::attr::kTarget).value();
if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_fp8")) {
return f;
}
diff --git a/tests/python/tir-transform/test_tir_transform_fp8_legalize.py
b/tests/python/tir-transform/test_tir_transform_fp8_legalize.py
index 6e44b53d0c..e1f487c572 100644
--- a/tests/python/tir-transform/test_tir_transform_fp8_legalize.py
+++ b/tests/python/tir-transform/test_tir_transform_fp8_legalize.py
@@ -19,6 +19,7 @@ import tvm.script
import tvm.testing
from tvm.target import Target
from tvm.script import tir as T
+from tvm.tir.transform.transform import BindTarget
# pylint: disable=no-member,invalid-name,unused-variable
@@ -206,20 +207,20 @@ promote_dtype = tvm.testing.parameter("float16",
"float32")
def test_fp8_compute_legalize(dtype, promote_dtype):
target = Target("cuda")
- before = get_before(dtype)
- expected = get_after_compute_legalize(dtype, promote_dtype)
+ before = BindTarget(target)(get_before(dtype))
+ expected = BindTarget(target)(get_after_compute_legalize(dtype,
promote_dtype))
# run the transform twice to ensure we can afford to deal
# with these repetitive optimizations
- after = tvm.tir.transform.FP8ComputeLegalize(target, promote_dtype)(before)
- after = tvm.tir.transform.FP8ComputeLegalize(target, promote_dtype)(after)
+ after = tvm.tir.transform.FP8ComputeLegalize(promote_dtype)(before)
+ after = tvm.tir.transform.FP8ComputeLegalize(promote_dtype)(after)
tvm.ir.assert_structural_equal(after, expected)
def test_fp8_storage_legalize(dtype, promote_dtype):
target = Target("cuda")
- before = get_after_compute_legalize(dtype, promote_dtype)
- after = tvm.tir.transform.FP8StorageLegalize(target)(before)
- expected = get_after_storage_legalize(dtype, promote_dtype)
+ before = BindTarget(target)(get_after_compute_legalize(dtype,
promote_dtype))
+ after = tvm.tir.transform.FP8StorageLegalize()(before)
+ expected = BindTarget(target)(get_after_storage_legalize(dtype,
promote_dtype))
tvm.ir.assert_structural_equal(after, expected)