This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 59bfb21559 [CodeGen][CUDA] Move fast math intrinsic lowering option to PassContext (#19596)
59bfb21559 is described below

commit 59bfb2155981672f12754f4c29f3a1f1c0055fe1
Author: Shushi Hong <[email protected]>
AuthorDate: Sun May 24 10:30:00 2026 -0400

    [CodeGen][CUDA] Move fast math intrinsic lowering option to PassContext (#19596)
    
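    This updates CUDA fast math intrinsic lowering to use a PassContext
    option instead of a CUDA Target attribute. The `enable_fast_math`
    attribute is removed from the `cuda` target kind and from the CUDA and
    Jetson target tags, so code that previously set it on a Target should
    switch to the PassContext option below.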
    
    The new option is:
    
    ```python
    with tvm.transform.PassContext(config={"tirx.enable_fast_math": True}):
        ...
    ```
    
    When unset or false, CUDA math intrinsics continue to lower to the
    precise CUDA math functions such as expf. When true, tirx.LowerIntrin
    prioritizes the cuda.fastmath.* lowering rules, producing fast math
    intrinsics such as __expf.
---
 python/tvm/target/detect_target.py                     |  1 -
 python/tvm/target/tag_registry/cuda.py                 |  5 +----
 src/target/target_kind.cc                              |  9 ---------
 src/tirx/ir/transform.cc                               |  1 +
 src/tirx/transform/lower_intrin.cc                     | 18 +++++++++++-------
 .../codegen/test_target_codegen_cuda_fastmath.py       | 13 ++++++++++---
 tests/python/target/test_target_target.py              | 14 +-------------
 7 files changed, 24 insertions(+), 37 deletions(-)

diff --git a/python/tvm/target/detect_target.py 
b/python/tvm/target/detect_target.py
index f7d79ba434..81accfed12 100644
--- a/python/tvm/target/detect_target.py
+++ b/python/tvm/target/detect_target.py
@@ -41,7 +41,6 @@ def _detect_cuda(dev: Device) -> Target:
             "max_threads_per_block": dev.max_threads_per_block,
             "thread_warp_size": dev.warp_size,
             "arch": "sm_" + dev.compute_version.replace(".", ""),
-            "enable_fast_math": False,
         }
     )
 
diff --git a/python/tvm/target/tag_registry/cuda.py 
b/python/tvm/target/tag_registry/cuda.py
index d3740cb515..6b1bd9e8a8 100644
--- a/python/tvm/target/tag_registry/cuda.py
+++ b/python/tvm/target/tag_registry/cuda.py
@@ -28,14 +28,12 @@ def _register_cuda_tag(name, arch, shared_mem=49152, 
regs=65536, **extra):
         "max_threads_per_block": 1024,
         "thread_warp_size": 32,
         "registers_per_block": regs,
-        # Default to disable fast math
-        "enable_fast_math": False,
     }
     config.update(extra)
     register_tag(name, config)
 
 
-def _register_jetson_tag(name, arch, mcpu, num_cores, regs=65536, 
enable_fast_math=False):
+def _register_jetson_tag(name, arch, mcpu, num_cores, regs=65536):
     register_tag(
         name,
         {
@@ -51,7 +49,6 @@ def _register_jetson_tag(name, arch, mcpu, num_cores, 
regs=65536, enable_fast_ma
                 "mcpu": mcpu,
                 "num-cores": num_cores,
             },
-            "enable_fast_math": enable_fast_math,
         },
     )
 
diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc
index d6a8d30c4f..5779b4da0e 100644
--- a/src/target/target_kind.cc
+++ b/src/target/target_kind.cc
@@ -188,14 +188,6 @@ ffi::Map<ffi::String, ffi::Any> 
UpdateCUDAAttrs(ffi::Map<ffi::String, ffi::Any>
       target.Set("arch", ffi::String("sm_") + std::to_string(archInt));
     }
   }
-  // Update enable_fast_math
-  if (target.count("enable_fast_math")) {
-    // If enable_fast_math has been specified, validate that enable_fast_math 
is a bool
-    Downcast<bool>(target.at("enable_fast_math"));
-  } else {
-    // If enable_fast_math has not been specified, default to false
-    target.Set("enable_fast_math", false);
-  }
   return target;
 }
 
@@ -380,7 +372,6 @@ TVM_REGISTER_TARGET_KIND("cuda", kDLCUDA)
     .add_attr_option<int64_t>("l2_cache_size_bytes")
     .add_attr_option<int64_t>("max_num_threads",
                               refl::DefaultValue(1024))  // TODO(@zxybazh): 
deprecate it
-    .add_attr_option<bool>("enable_fast_math")
     .set_default_keys({"cuda", "gpu"})
     .set_target_canonicalizer(UpdateCUDAAttrs);
 
diff --git a/src/tirx/ir/transform.cc b/src/tirx/ir/transform.cc
index ac651410d9..d336d05726 100644
--- a/src/tirx/ir/transform.cc
+++ b/src/tirx/ir/transform.cc
@@ -48,6 +48,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tirx.merge_static_smem", 
Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tirx.instrument_lwp", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tirx.vtcm_capacity", Integer);
 TVM_REGISTER_PASS_CONFIG_OPTION("tirx.ptx_ldg32", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tirx.enable_fast_math", Bool);
 
 /*!
  * \brief Function level pass that applies transformations to all
diff --git a/src/tirx/transform/lower_intrin.cc 
b/src/tirx/transform/lower_intrin.cc
index 7f4b1aa30b..a580e33fdd 100644
--- a/src/tirx/transform/lower_intrin.cc
+++ b/src/tirx/transform/lower_intrin.cc
@@ -46,15 +46,15 @@ class IntrinInjecter : public 
tvm::arith::IRMutatorWithAnalyzer {
   using IRMutatorWithAnalyzer::VisitStmt_;
   using FLowerGeneral = ffi::TypedFunction<PrimExpr(PrimExpr)>;
 
-  IntrinInjecter(arith::Analyzer* analyzer, const Target& tgt) : 
IRMutatorWithAnalyzer(analyzer) {
+  IntrinInjecter(arith::Analyzer* analyzer, const Target& tgt, bool 
enable_fast_math)
+      : IRMutatorWithAnalyzer(analyzer) {
     std::string target = tgt->kind->name;
     ffi::String mtriple = tgt->GetAttr<ffi::String>("mtriple").value_or("");
 
     std::vector<std::string> patterns;
-    // For CUDA targets, we need to add the fast math patterns if 
enable_fast_math is true.
-    // The priority of the fast math patterns is higher than the normal 
patterns.
-    bool is_fast_math = tgt->GetAttr<bool>("enable_fast_math").value_or(false);
-    if (is_fast_math) {
+    // Add the fast math patterns when requested.  The priority of the fast 
math
+    // patterns is higher than the normal patterns.
+    if (enable_fast_math) {
       patterns.push_back(target + ".fastmath.FLowerIntrinsic");
       patterns.push_back(target + ".fastmath.FLegalize");
     }
@@ -364,7 +364,10 @@ class IntrinInjecter : public 
tvm::arith::IRMutatorWithAnalyzer {
 
 Stmt LowerIntrinStmt(Stmt stmt, const std::string& target) {
   arith::Analyzer analyzer;
-  return IntrinInjecter(&analyzer, 
Target(ffi::String(target)))(std::move(stmt));
+  bool enable_fast_math = transform::PassContext::Current()
+                              ->GetConfig<Bool>("tirx.enable_fast_math", 
Bool(false))
+                              .value();
+  return IntrinInjecter(&analyzer, Target(ffi::String(target)), 
enable_fast_math)(std::move(stmt));
 }
 
 namespace transform {
@@ -375,7 +378,8 @@ Pass LowerIntrin() {
     auto target = f->GetAttr<Target>(tvm::attr::kTarget);
     TVM_FFI_ICHECK(target.defined()) << "LowerIntrin: Require the target 
attribute";
     arith::Analyzer analyzer;
-    n->body = IntrinInjecter(&analyzer, target.value())(std::move(n->body));
+    bool enable_fast_math = ctx->GetConfig<Bool>("tirx.enable_fast_math", 
Bool(false)).value();
+    n->body = IntrinInjecter(&analyzer, target.value(), 
enable_fast_math)(std::move(n->body));
     return f;
   };
   return CreatePrimFuncPass(pass_func, 0, "tirx.LowerIntrin", {});
diff --git a/tests/python/codegen/test_target_codegen_cuda_fastmath.py 
b/tests/python/codegen/test_target_codegen_cuda_fastmath.py
index 84cac4361e..a3a9d4a308 100644
--- a/tests/python/codegen/test_target_codegen_cuda_fastmath.py
+++ b/tests/python/codegen/test_target_codegen_cuda_fastmath.py
@@ -203,7 +203,7 @@ def make_mod(
     dtype: str, case: MathCase, enable_fast_math: bool
 ) -> tuple[tvm.target.Target, tvm.IRModule]:
     """Make a module for the given dtype and case."""
-    target = tvm.target.Target({"kind": "cuda", "enable_fast_math": 
enable_fast_math})
+    target = tvm.target.Target("cuda")
     prim_func = make_prim_func(case.name, dtype, case.num_inputs, case.op)
     return target, tvm.IRModule.from_expr(prim_func.with_attr("target", 
target))
 
@@ -227,7 +227,8 @@ def check_lowered_ir(
 ) -> tuple[tvm.target.Target, IRModule]:
     """Check the lowered IR for the given dtype and case."""
     target, mod = make_mod(dtype, case, enable_fast_math)
-    lowered_mod = tvm.tirx.transform.LowerIntrin()(mod)
+    with tvm.transform.PassContext(config={"tirx.enable_fast_math": 
enable_fast_math}):
+        lowered_mod = tvm.tirx.transform.LowerIntrin()(mod)
     script = lowered_mod.script(show_meta=False)
     expected = expected_intrinsic(dtype, case, enable_fast_math)
     assert re.search(rf"""["']{re.escape(expected)}["']""", script)
@@ -242,7 +243,8 @@ def check_cuda_source(
     enable_fast_math: bool,
 ) -> Executable:
     """Check the CUDA source for the given dtype and case."""
-    executable = tvm.compile(mod, target=target)
+    with tvm.transform.PassContext(config={"tirx.enable_fast_math": 
enable_fast_math}):
+        executable = tvm.compile(mod, target=target)
     source = executable.mod.imports[0].inspect_source()
     expected = expected_intrinsic(dtype, case, enable_fast_math)
     assert re.search(rf"(?<!_)\b{re.escape(expected)}\s*\(", source)
@@ -279,6 +281,11 @@ def check_runtime(dtype: str, case: MathCase, executable: 
Executable):
     np.testing.assert_allclose(actual, expected, rtol=case.rtol, 
atol=case.atol)
 
 
+@pytest.mark.parametrize("enable_fast_math", [False, True], ids=["default", "fast_math"])
+def test_cuda_math_intrinsic_lowering_pass_context(enable_fast_math):
+    check_lowered_ir("float32", MATH_CASES[0], enable_fast_math)
+
+
 @tvm.testing.requires_gpu
 @tvm.testing.requires_cuda
 @pytest.mark.parametrize(
diff --git a/tests/python/target/test_target_target.py 
b/tests/python/target/test_target_target.py
index 94706ee8d8..c037fcadd2 100644
--- a/tests/python/target/test_target_target.py
+++ b/tests/python/target/test_target_target.py
@@ -148,7 +148,6 @@ def test_target_tag_0():
     assert tgt.attrs["max_threads_per_block"] == 1024
     assert tgt.attrs["thread_warp_size"] == 32
     assert tgt.attrs["registers_per_block"] == 65536
-    assert not tgt.attrs["enable_fast_math"]
 
 
 def test_target_tag_1():
@@ -159,19 +158,15 @@ def test_target_tag_1():
     assert tgt.attrs["max_threads_per_block"] == 1024
     assert tgt.attrs["thread_warp_size"] == 32
     assert tgt.attrs["registers_per_block"] == 32768
-    assert not tgt.attrs["enable_fast_math"]
 
 
 def test_target_tag_override():
     """Test creating a target from a tag with attribute overrides."""
-    tgt = tvm.target.Target(
-        {"tag": "nvidia/nvidia-a100", "l2_cache_size_bytes": 12345, 
"enable_fast_math": True}
-    )
+    tgt = tvm.target.Target({"tag": "nvidia/nvidia-a100", 
"l2_cache_size_bytes": 12345})
     assert tgt.kind.name == "cuda"
     assert tgt.attrs["arch"] == "sm_80"
     # Override should take effect
     assert int(tgt.attrs["l2_cache_size_bytes"]) == 12345
-    assert tgt.attrs["enable_fast_math"]
     # Base tag fields should be preserved
     assert tgt.attrs["max_shared_memory_per_block"] == 49152
     assert tgt.attrs["thread_warp_size"] == 32
@@ -194,14 +189,12 @@ def test_target_host_tags():
     assert tgt.attrs["max_threads_per_block"] == 1024
     assert tgt.attrs["thread_warp_size"] == 32
     assert tgt.attrs["registers_per_block"] == 32768
-    assert not tgt.attrs["enable_fast_math"]
     assert tgt.host.kind.name == "cuda"
     assert tgt.host.attrs["arch"] == "sm_75"
     assert tgt.host.attrs["max_shared_memory_per_block"] == 49152
     assert tgt.host.attrs["max_threads_per_block"] == 1024
     assert tgt.host.attrs["thread_warp_size"] == 32
     assert tgt.host.attrs["registers_per_block"] == 65536
-    assert not tgt.host.attrs["enable_fast_math"]
 
 
 def test_target_host_tag_dict():
@@ -212,7 +205,6 @@ def test_target_host_tag_dict():
     assert tgt.attrs["max_threads_per_block"] == 1024
     assert tgt.attrs["thread_warp_size"] == 32
     assert tgt.attrs["registers_per_block"] == 32768
-    assert not tgt.attrs["enable_fast_math"]
     assert tgt.host.kind.name == "llvm"
 
 
@@ -225,7 +217,6 @@ def test_target_host_single_dict():
     assert tgt.host.attrs["max_threads_per_block"] == 1024
     assert tgt.host.attrs["thread_warp_size"] == 32
     assert tgt.host.attrs["registers_per_block"] == 32768
-    assert not tgt.host.attrs["enable_fast_math"]
 
 
 def test_target_host_single_string():
@@ -243,7 +234,6 @@ def test_target_host_single_string_with_tag():
     assert tgt.host.attrs["max_threads_per_block"] == 1024
     assert tgt.host.attrs["thread_warp_size"] == 32
     assert tgt.host.attrs["registers_per_block"] == 32768
-    assert not tgt.host.attrs["enable_fast_math"]
 
 
 def test_target_host_merge_0():
@@ -255,7 +245,6 @@ def test_target_host_merge_0():
     assert tgt.host.attrs["max_threads_per_block"] == 1024
     assert tgt.host.attrs["thread_warp_size"] == 32
     assert tgt.host.attrs["registers_per_block"] == 32768
-    assert not tgt.host.attrs["enable_fast_math"]
 
 
 def test_target_host_merge_1():
@@ -306,7 +295,6 @@ def test_target_with_host():
     assert tgt.host.attrs["max_threads_per_block"] == 1024
     assert tgt.host.attrs["thread_warp_size"] == 32
     assert tgt.host.attrs["registers_per_block"] == 32768
-    assert not tgt.host.attrs["enable_fast_math"]
 
 
 def test_target_attr_bool_value():

Reply via email to