This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new f4cf9f578e [BugFix][TIRx] Fix bad-optional-access in BF16/FP8 legalize
passes for target-less PrimFuncs (#19383)
f4cf9f578e is described below
commit f4cf9f578e7de5960fa38429c20430001cc507af
Author: Soowon Jeong <[email protected]>
AuthorDate: Sat Apr 11 08:54:39 2026 +0900
[BugFix][TIRx] Fix bad-optional-access in BF16/FP8 legalize passes for
target-less PrimFuncs (#19383)
## Problem
`BF16ComputeLegalize`, `BF16StorageLegalize`, `FP8ComputeLegalize`, and
`FP8StorageLegalize` all call `f->GetAttr<Target>(kTarget).value()`
unconditionally. Host-side helper PrimFuncs produced during Relax
lowering (e.g. for `reshape`, `mean`) carry no target attribute;
calling `.value()` on an empty `Optional<Target>` throws `std::bad_optional_access`, which is not caught and terminates the process:
```
terminate called after throwing an instance of 'std::bad_optional_access'
```
This surfaces whenever `tvm.compile` targets CUDA on a model with
BF16/FP8 support compiled in, because the Relax lowering pipeline mixes
target-annotated GPU kernels with target-less CPU helpers in the same
module.
## Fix
Retrieve the target into `opt_target` and combine the `defined()` check
with `CheckDataTypeSupport`:
```c++
if (opt_target.defined() && CheckDataTypeSupport(opt_target.value(), "tvm.contrib.nvcc.supports_bf16")) {
return f;
}
```
- If the target is absent: legalization runs. This is safe — it is a
no-op when no BF16/FP8 types are present, and it ensures correctness for
the edge case where a target-less function does contain such types.
- If the target confirms native support: legalization is skipped as
before.
- If the target lacks native support: legalization runs as before.
This matches the handling of explicit host targets (e.g. `llvm`), which
also go through legalization. The same change is applied to all four
legalize pass lambdas.
---
src/tirx/transform/unsupported_dtype_legalize.cc | 16 ++++++++--------
1 file changed, 8 insertions(+), 8 deletions(-)
diff --git a/src/tirx/transform/unsupported_dtype_legalize.cc b/src/tirx/transform/unsupported_dtype_legalize.cc
index 555f2bbbcc..402a0e8558 100644
--- a/src/tirx/transform/unsupported_dtype_legalize.cc
+++ b/src/tirx/transform/unsupported_dtype_legalize.cc
@@ -745,8 +745,8 @@ bool CheckDataTypeSupport(const Target& target, const std::string& support_func_
Pass BF16ComputeLegalize() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
- auto target = f->GetAttr<Target>(tvm::attr::kTarget).value();
- if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_bf16")) {
+ auto opt_target = f->GetAttr<Target>(tvm::attr::kTarget);
+ if (opt_target.defined() && CheckDataTypeSupport(opt_target.value(), "tvm.contrib.nvcc.supports_bf16")) {
return f;
}
return BF16ComputeLegalizer().Legalize(f);
@@ -761,8 +761,8 @@ TVM_FFI_STATIC_INIT_BLOCK() {
Pass BF16StorageLegalize() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
- auto target = f->GetAttr<Target>(tvm::attr::kTarget).value();
- if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_bf16")) {
+ auto opt_target = f->GetAttr<Target>(tvm::attr::kTarget);
+ if (opt_target.defined() && CheckDataTypeSupport(opt_target.value(), "tvm.contrib.nvcc.supports_bf16")) {
return f;
}
return BF16StorageLegalizer().Legalize(f);
@@ -777,8 +777,8 @@ TVM_FFI_STATIC_INIT_BLOCK() {
Pass FP8ComputeLegalize(ffi::String promote_dtype) {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
- auto target = f->GetAttr<Target>(tvm::attr::kTarget).value();
- if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_fp8")) {
+ auto opt_target = f->GetAttr<Target>(tvm::attr::kTarget);
+ if (opt_target.defined() && CheckDataTypeSupport(opt_target.value(), "tvm.contrib.nvcc.supports_fp8")) {
return f;
}
return FP8ComputeLegalizer(DataType(ffi::StringToDLDataType(promote_dtype))).Legalize(f);
@@ -793,8 +793,8 @@ TVM_FFI_STATIC_INIT_BLOCK() {
Pass FP8StorageLegalize() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
- auto target = f->GetAttr<Target>(tvm::attr::kTarget).value();
- if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_fp8")) {
+ auto opt_target = f->GetAttr<Target>(tvm::attr::kTarget);
+ if (opt_target.defined() && CheckDataTypeSupport(opt_target.value(), "tvm.contrib.nvcc.supports_fp8")) {
return f;
}
return FP8StorageLegalizer().Legalize(f);