This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 59bfb21559 [CodeGen][CUDA] Move fast math intrinsic lowering option to PassContext (#19596)
59bfb21559 is described below
commit 59bfb2155981672f12754f4c29f3a1f1c0055fe1
Author: Shushi Hong <[email protected]>
AuthorDate: Sun May 24 10:30:00 2026 -0400
[CodeGen][CUDA] Move fast math intrinsic lowering option to PassContext (#19596)
This changes CUDA fast math intrinsic lowering to be controlled by a
PassContext config option instead of the CUDA Target attribute
`enable_fast_math`, which is removed.
The new option is:
```python
with tvm.transform.PassContext(config={"tirx.enable_fast_math": True}):
...
```
When unset or false, CUDA math intrinsics continue to lower to the
precise CUDA math functions such as expf. When true, tirx.LowerIntrin
tries the target's fast math lowering rules (`cuda.fastmath.*` for CUDA)
before the regular rules, producing fast math intrinsics such as __expf.
---
python/tvm/target/detect_target.py | 1 -
python/tvm/target/tag_registry/cuda.py | 5 +----
src/target/target_kind.cc | 9 ---------
src/tirx/ir/transform.cc | 1 +
src/tirx/transform/lower_intrin.cc | 18 +++++++++++-------
.../codegen/test_target_codegen_cuda_fastmath.py | 13 ++++++++++---
tests/python/target/test_target_target.py | 14 +-------------
7 files changed, 24 insertions(+), 37 deletions(-)
diff --git a/python/tvm/target/detect_target.py
b/python/tvm/target/detect_target.py
index f7d79ba434..81accfed12 100644
--- a/python/tvm/target/detect_target.py
+++ b/python/tvm/target/detect_target.py
@@ -41,7 +41,6 @@ def _detect_cuda(dev: Device) -> Target:
"max_threads_per_block": dev.max_threads_per_block,
"thread_warp_size": dev.warp_size,
"arch": "sm_" + dev.compute_version.replace(".", ""),
- "enable_fast_math": False,
}
)
diff --git a/python/tvm/target/tag_registry/cuda.py
b/python/tvm/target/tag_registry/cuda.py
index d3740cb515..6b1bd9e8a8 100644
--- a/python/tvm/target/tag_registry/cuda.py
+++ b/python/tvm/target/tag_registry/cuda.py
@@ -28,14 +28,12 @@ def _register_cuda_tag(name, arch, shared_mem=49152,
regs=65536, **extra):
"max_threads_per_block": 1024,
"thread_warp_size": 32,
"registers_per_block": regs,
- # Default to disable fast math
- "enable_fast_math": False,
}
config.update(extra)
register_tag(name, config)
-def _register_jetson_tag(name, arch, mcpu, num_cores, regs=65536,
enable_fast_math=False):
+def _register_jetson_tag(name, arch, mcpu, num_cores, regs=65536):
register_tag(
name,
{
@@ -51,7 +49,6 @@ def _register_jetson_tag(name, arch, mcpu, num_cores,
regs=65536, enable_fast_ma
"mcpu": mcpu,
"num-cores": num_cores,
},
- "enable_fast_math": enable_fast_math,
},
)
diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc
index d6a8d30c4f..5779b4da0e 100644
--- a/src/target/target_kind.cc
+++ b/src/target/target_kind.cc
@@ -188,14 +188,6 @@ ffi::Map<ffi::String, ffi::Any>
UpdateCUDAAttrs(ffi::Map<ffi::String, ffi::Any>
target.Set("arch", ffi::String("sm_") + std::to_string(archInt));
}
}
- // Update enable_fast_math
- if (target.count("enable_fast_math")) {
- // If enable_fast_math has been specified, validate that enable_fast_math
is a bool
- Downcast<bool>(target.at("enable_fast_math"));
- } else {
- // If enable_fast_math has not been specified, default to false
- target.Set("enable_fast_math", false);
- }
return target;
}
@@ -380,7 +372,6 @@ TVM_REGISTER_TARGET_KIND("cuda", kDLCUDA)
.add_attr_option<int64_t>("l2_cache_size_bytes")
.add_attr_option<int64_t>("max_num_threads",
refl::DefaultValue(1024)) // TODO(@zxybazh):
deprecate it
- .add_attr_option<bool>("enable_fast_math")
.set_default_keys({"cuda", "gpu"})
.set_target_canonicalizer(UpdateCUDAAttrs);
diff --git a/src/tirx/ir/transform.cc b/src/tirx/ir/transform.cc
index ac651410d9..d336d05726 100644
--- a/src/tirx/ir/transform.cc
+++ b/src/tirx/ir/transform.cc
@@ -48,6 +48,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tirx.merge_static_smem",
Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tirx.instrument_lwp", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tirx.vtcm_capacity", Integer);
TVM_REGISTER_PASS_CONFIG_OPTION("tirx.ptx_ldg32", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tirx.enable_fast_math", Bool);
/*!
* \brief Function level pass that applies transformations to all
diff --git a/src/tirx/transform/lower_intrin.cc
b/src/tirx/transform/lower_intrin.cc
index 7f4b1aa30b..a580e33fdd 100644
--- a/src/tirx/transform/lower_intrin.cc
+++ b/src/tirx/transform/lower_intrin.cc
@@ -46,15 +46,15 @@ class IntrinInjecter : public
tvm::arith::IRMutatorWithAnalyzer {
using IRMutatorWithAnalyzer::VisitStmt_;
using FLowerGeneral = ffi::TypedFunction<PrimExpr(PrimExpr)>;
- IntrinInjecter(arith::Analyzer* analyzer, const Target& tgt) :
IRMutatorWithAnalyzer(analyzer) {
+ IntrinInjecter(arith::Analyzer* analyzer, const Target& tgt, bool
enable_fast_math)
+ : IRMutatorWithAnalyzer(analyzer) {
std::string target = tgt->kind->name;
ffi::String mtriple = tgt->GetAttr<ffi::String>("mtriple").value_or("");
std::vector<std::string> patterns;
- // For CUDA targets, we need to add the fast math patterns if
enable_fast_math is true.
- // The priority of the fast math patterns is higher than the normal
patterns.
- bool is_fast_math = tgt->GetAttr<bool>("enable_fast_math").value_or(false);
- if (is_fast_math) {
+ // Add the fast math patterns when requested. The priority of the fast
math
+ // patterns is higher than the normal patterns.
+ if (enable_fast_math) {
patterns.push_back(target + ".fastmath.FLowerIntrinsic");
patterns.push_back(target + ".fastmath.FLegalize");
}
@@ -364,7 +364,10 @@ class IntrinInjecter : public
tvm::arith::IRMutatorWithAnalyzer {
Stmt LowerIntrinStmt(Stmt stmt, const std::string& target) {
arith::Analyzer analyzer;
- return IntrinInjecter(&analyzer,
Target(ffi::String(target)))(std::move(stmt));
+ bool enable_fast_math = transform::PassContext::Current()
+ ->GetConfig<Bool>("tirx.enable_fast_math",
Bool(false))
+ .value();
+ return IntrinInjecter(&analyzer, Target(ffi::String(target)),
enable_fast_math)(std::move(stmt));
}
namespace transform {
@@ -375,7 +378,8 @@ Pass LowerIntrin() {
auto target = f->GetAttr<Target>(tvm::attr::kTarget);
TVM_FFI_ICHECK(target.defined()) << "LowerIntrin: Require the target
attribute";
arith::Analyzer analyzer;
- n->body = IntrinInjecter(&analyzer, target.value())(std::move(n->body));
+ bool enable_fast_math = ctx->GetConfig<Bool>("tirx.enable_fast_math",
Bool(false)).value();
+ n->body = IntrinInjecter(&analyzer, target.value(),
enable_fast_math)(std::move(n->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tirx.LowerIntrin", {});
diff --git a/tests/python/codegen/test_target_codegen_cuda_fastmath.py
b/tests/python/codegen/test_target_codegen_cuda_fastmath.py
index 84cac4361e..a3a9d4a308 100644
--- a/tests/python/codegen/test_target_codegen_cuda_fastmath.py
+++ b/tests/python/codegen/test_target_codegen_cuda_fastmath.py
@@ -203,7 +203,7 @@ def make_mod(
dtype: str, case: MathCase, enable_fast_math: bool
) -> tuple[tvm.target.Target, tvm.IRModule]:
"""Make a module for the given dtype and case."""
- target = tvm.target.Target({"kind": "cuda", "enable_fast_math":
enable_fast_math})
+ target = tvm.target.Target("cuda")
prim_func = make_prim_func(case.name, dtype, case.num_inputs, case.op)
return target, tvm.IRModule.from_expr(prim_func.with_attr("target",
target))
@@ -227,7 +227,8 @@ def check_lowered_ir(
) -> tuple[tvm.target.Target, IRModule]:
"""Check the lowered IR for the given dtype and case."""
target, mod = make_mod(dtype, case, enable_fast_math)
- lowered_mod = tvm.tirx.transform.LowerIntrin()(mod)
+ with tvm.transform.PassContext(config={"tirx.enable_fast_math":
enable_fast_math}):
+ lowered_mod = tvm.tirx.transform.LowerIntrin()(mod)
script = lowered_mod.script(show_meta=False)
expected = expected_intrinsic(dtype, case, enable_fast_math)
assert re.search(rf"""["']{re.escape(expected)}["']""", script)
@@ -242,7 +243,8 @@ def check_cuda_source(
enable_fast_math: bool,
) -> Executable:
"""Check the CUDA source for the given dtype and case."""
- executable = tvm.compile(mod, target=target)
+ with tvm.transform.PassContext(config={"tirx.enable_fast_math":
enable_fast_math}):
+ executable = tvm.compile(mod, target=target)
source = executable.mod.imports[0].inspect_source()
expected = expected_intrinsic(dtype, case, enable_fast_math)
assert re.search(rf"(?<!_)\b{re.escape(expected)}\s*\(", source)
@@ -279,6 +281,11 @@ def check_runtime(dtype: str, case: MathCase, executable:
Executable):
np.testing.assert_allclose(actual, expected, rtol=case.rtol,
atol=case.atol)
+@pytest.mark.parametrize("enable_fast_math", [False, True], ids=["default",
"fast_math"])
+def test_cuda_math_intrinsic_lowering_pass_context(enable_fast_math):
+ check_lowered_ir("float32", MATH_CASES[0], enable_fast_math)
+
+
@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
@pytest.mark.parametrize(
diff --git a/tests/python/target/test_target_target.py
b/tests/python/target/test_target_target.py
index 94706ee8d8..c037fcadd2 100644
--- a/tests/python/target/test_target_target.py
+++ b/tests/python/target/test_target_target.py
@@ -148,7 +148,6 @@ def test_target_tag_0():
assert tgt.attrs["max_threads_per_block"] == 1024
assert tgt.attrs["thread_warp_size"] == 32
assert tgt.attrs["registers_per_block"] == 65536
- assert not tgt.attrs["enable_fast_math"]
def test_target_tag_1():
@@ -159,19 +158,15 @@ def test_target_tag_1():
assert tgt.attrs["max_threads_per_block"] == 1024
assert tgt.attrs["thread_warp_size"] == 32
assert tgt.attrs["registers_per_block"] == 32768
- assert not tgt.attrs["enable_fast_math"]
def test_target_tag_override():
"""Test creating a target from a tag with attribute overrides."""
- tgt = tvm.target.Target(
- {"tag": "nvidia/nvidia-a100", "l2_cache_size_bytes": 12345,
"enable_fast_math": True}
- )
+ tgt = tvm.target.Target({"tag": "nvidia/nvidia-a100",
"l2_cache_size_bytes": 12345})
assert tgt.kind.name == "cuda"
assert tgt.attrs["arch"] == "sm_80"
# Override should take effect
assert int(tgt.attrs["l2_cache_size_bytes"]) == 12345
- assert tgt.attrs["enable_fast_math"]
# Base tag fields should be preserved
assert tgt.attrs["max_shared_memory_per_block"] == 49152
assert tgt.attrs["thread_warp_size"] == 32
@@ -194,14 +189,12 @@ def test_target_host_tags():
assert tgt.attrs["max_threads_per_block"] == 1024
assert tgt.attrs["thread_warp_size"] == 32
assert tgt.attrs["registers_per_block"] == 32768
- assert not tgt.attrs["enable_fast_math"]
assert tgt.host.kind.name == "cuda"
assert tgt.host.attrs["arch"] == "sm_75"
assert tgt.host.attrs["max_shared_memory_per_block"] == 49152
assert tgt.host.attrs["max_threads_per_block"] == 1024
assert tgt.host.attrs["thread_warp_size"] == 32
assert tgt.host.attrs["registers_per_block"] == 65536
- assert not tgt.host.attrs["enable_fast_math"]
def test_target_host_tag_dict():
@@ -212,7 +205,6 @@ def test_target_host_tag_dict():
assert tgt.attrs["max_threads_per_block"] == 1024
assert tgt.attrs["thread_warp_size"] == 32
assert tgt.attrs["registers_per_block"] == 32768
- assert not tgt.attrs["enable_fast_math"]
assert tgt.host.kind.name == "llvm"
@@ -225,7 +217,6 @@ def test_target_host_single_dict():
assert tgt.host.attrs["max_threads_per_block"] == 1024
assert tgt.host.attrs["thread_warp_size"] == 32
assert tgt.host.attrs["registers_per_block"] == 32768
- assert not tgt.host.attrs["enable_fast_math"]
def test_target_host_single_string():
@@ -243,7 +234,6 @@ def test_target_host_single_string_with_tag():
assert tgt.host.attrs["max_threads_per_block"] == 1024
assert tgt.host.attrs["thread_warp_size"] == 32
assert tgt.host.attrs["registers_per_block"] == 32768
- assert not tgt.host.attrs["enable_fast_math"]
def test_target_host_merge_0():
@@ -255,7 +245,6 @@ def test_target_host_merge_0():
assert tgt.host.attrs["max_threads_per_block"] == 1024
assert tgt.host.attrs["thread_warp_size"] == 32
assert tgt.host.attrs["registers_per_block"] == 32768
- assert not tgt.host.attrs["enable_fast_math"]
def test_target_host_merge_1():
@@ -306,7 +295,6 @@ def test_target_with_host():
assert tgt.host.attrs["max_threads_per_block"] == 1024
assert tgt.host.attrs["thread_warp_size"] == 32
assert tgt.host.attrs["registers_per_block"] == 32768
- assert not tgt.host.attrs["enable_fast_math"]
def test_target_attr_bool_value():