This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new f4cf9f578e [BugFix][TIRx] Fix bad-optional-access in BF16/FP8 legalize passes for target-less PrimFuncs (#19383)
f4cf9f578e is described below

commit f4cf9f578e7de5960fa38429c20430001cc507af
Author: Soowon Jeong <[email protected]>
AuthorDate: Sat Apr 11 08:54:39 2026 +0900

    [BugFix][TIRx] Fix bad-optional-access in BF16/FP8 legalize passes for target-less PrimFuncs (#19383)
    
    ## Problem
    
    `BF16ComputeLegalize`, `BF16StorageLegalize`, `FP8ComputeLegalize`, and
    `FP8StorageLegalize` all call `f->GetAttr<Target>(kTarget).value()`
    unconditionally. Host-side helper PrimFuncs produced during Relax
    lowering (e.g. for `reshape`, `mean`) carry no target attribute;
    `.value()` on an empty `Optional<Target>` aborts at runtime:
    
    ```
    terminate called after throwing an instance of 'std::bad_optional_access'
    ```
    
    This surfaces whenever `tvm.compile` targets CUDA on a model with
    BF16/FP8 support compiled in, because the Relax lowering pipeline mixes
    target-annotated GPU kernels with target-less CPU helpers in the same
    module.
    
    ## Fix
    
    Retrieve the target as an `Optional<Target>` (`opt_target`) and guard the
    `CheckDataTypeSupport` call with `opt_target.defined()`:
    
    ```c++
    if (opt_target.defined() && CheckDataTypeSupport(opt_target.value(), "tvm.contrib.nvcc.supports_bf16")) {
        return f;
    }
    ```
    
    - If the target is absent, legalization runs. This is safe: legalization
    is a no-op when no BF16/FP8 types are present, and it guarantees
    correctness in the edge case where a target-less function does contain
    such types.
    - If the target confirms native support: legalization is skipped as
    before.
    - If the target lacks native support: legalization runs as before.
    
    This matches how explicit host targets (e.g. `llvm`) are already handled:
    they also go through legalization. The same change is applied to all
    four legalize pass lambdas.
---
 src/tirx/transform/unsupported_dtype_legalize.cc | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/src/tirx/transform/unsupported_dtype_legalize.cc 
b/src/tirx/transform/unsupported_dtype_legalize.cc
index 555f2bbbcc..402a0e8558 100644
--- a/src/tirx/transform/unsupported_dtype_legalize.cc
+++ b/src/tirx/transform/unsupported_dtype_legalize.cc
@@ -745,8 +745,8 @@ bool CheckDataTypeSupport(const Target& target, const 
std::string& support_func_
 
 Pass BF16ComputeLegalize() {
   auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
-    auto target = f->GetAttr<Target>(tvm::attr::kTarget).value();
-    if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_bf16")) {
+    auto opt_target = f->GetAttr<Target>(tvm::attr::kTarget);
+    if (opt_target.defined() && CheckDataTypeSupport(opt_target.value(), 
"tvm.contrib.nvcc.supports_bf16")) {
       return f;
     }
     return BF16ComputeLegalizer().Legalize(f);
@@ -761,8 +761,8 @@ TVM_FFI_STATIC_INIT_BLOCK() {
 
 Pass BF16StorageLegalize() {
   auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
-    auto target = f->GetAttr<Target>(tvm::attr::kTarget).value();
-    if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_bf16")) {
+    auto opt_target = f->GetAttr<Target>(tvm::attr::kTarget);
+    if (opt_target.defined() && CheckDataTypeSupport(opt_target.value(), 
"tvm.contrib.nvcc.supports_bf16")) {
       return f;
     }
     return BF16StorageLegalizer().Legalize(f);
@@ -777,8 +777,8 @@ TVM_FFI_STATIC_INIT_BLOCK() {
 
 Pass FP8ComputeLegalize(ffi::String promote_dtype) {
   auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
-    auto target = f->GetAttr<Target>(tvm::attr::kTarget).value();
-    if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_fp8")) {
+    auto opt_target = f->GetAttr<Target>(tvm::attr::kTarget);
+    if (opt_target.defined() && CheckDataTypeSupport(opt_target.value(), 
"tvm.contrib.nvcc.supports_fp8")) {
       return f;
     }
     return 
FP8ComputeLegalizer(DataType(ffi::StringToDLDataType(promote_dtype))).Legalize(f);
@@ -793,8 +793,8 @@ TVM_FFI_STATIC_INIT_BLOCK() {
 
 Pass FP8StorageLegalize() {
   auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
-    auto target = f->GetAttr<Target>(tvm::attr::kTarget).value();
-    if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_fp8")) {
+    auto opt_target = f->GetAttr<Target>(tvm::attr::kTarget);
+    if (opt_target.defined() && CheckDataTypeSupport(opt_target.value(), 
"tvm.contrib.nvcc.supports_fp8")) {
       return f;
     }
     return FP8StorageLegalizer().Legalize(f);

Reply via email to