swjng opened a new pull request, #19514:
URL: https://github.com/apache/tvm/pull/19514
## Problem
Closes #17873.
`DefaultGPUSchedule` crashes when a PrimFunc body is a bare
`SBlockRealize` (a fully-scalar op with no enclosing loops and no iter
vars):
```
ValueError: Check failed: (sref->parent != nullptr) is false:
Cannot add loops on top of the root block
```
Minimal repro:
```python
@T.prim_func
def main(a: T.Buffer((), "float32"),
         b: T.Buffer((), "float32"),
         c: T.Buffer((), "float32")):
    T.func_attr({"target": T.target("nvidia/geforce-rtx-3080")})
    with T.sblock("scalar_add"):
        c[()] = a[()] + b[()]

# M is assumed to be the IRModule containing `main`; its definition is not shown here.
s_tir.transform.DefaultGPUSchedule()(M)  # crashes
```
## Root Cause
The realized `scalar_add` block is itself the root sref of the prim_func
body, so it has no parent stmt to mutate. `ThreadBind`
(`src/s_tir/transform/default_gpu_schedule.cc`) reaches the
`loops.empty()` branch and calls `sch->AddUnitLoop(block)`, which fails
the `sref->parent != nullptr` check in `s_tir::AddUnitLoop`
(`src/s_tir/schedule/primitive/loop_transformation.cc:1166`).
The schedule infrastructure additionally requires the prim_func body
to be an `SBlockRealize` whose block is the function's root
(`GetRootPrimFunc` in `src/s_tir/schedule/analysis/analysis.cc:53`),
so the body cannot simply be wrapped in a top-level `For`.
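The failure mode can be illustrated with a toy model of the sref tree. This is plain Python, not TVM code: `SRef` and `add_unit_loop` are hypothetical stand-ins that only mimic the `sref->parent != nullptr` check described above, under the assumption that adding a unit loop means splicing a new node between a block and its parent.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SRef:
    """Toy stand-in for a schedule sref: a statement name plus a parent link."""
    name: str
    parent: Optional["SRef"] = None


def add_unit_loop(sref: SRef) -> SRef:
    """Mimic the parent check: the unit loop is spliced between `sref` and its parent."""
    if sref.parent is None:
        raise ValueError("Cannot add loops on top of the root block")
    loop = SRef(name="unit_loop", parent=sref.parent)
    sref.parent = loop
    return loop


# Case 1: the scalar block is the root itself, so there is nothing to splice into.
bare = SRef("scalar_add")
try:
    add_unit_loop(bare)
    failed = False
except ValueError:
    failed = True

# Case 2: the scalar block sits under a separate root block, so the loop fits.
root = SRef("root")
nested = SRef("scalar_add", parent=root)
loop = add_unit_loop(nested)
```

In this model the only way to make the call succeed is to give the scalar block a parent, which is what the fix below arranges.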
## Fix
Before constructing the schedule, rewrite GPU-bound PrimFuncs whose
body is a bare-leaf `SBlockRealize` so the realized block is no longer
the root. The wrap conditions are intentionally narrow:
1. `func->body` is `SBlockRealize`,
2. the realized block has empty `iter_vars`, and
3. the block's body is **not** `For` or `SBlockRealize` (i.e. it is a
leaf computation, not the well-formed implicit root that wraps a
loop nest produced by the rest of the pipeline).
When all three hold, the body becomes:
```
SBlockRealize(
  block=SBlock("root", body=
    For(u, 0, 1, kSerial,
      SBlockRealize(iter_values=[u],
        block=<original block, iter_vars=[IterVar(0..1, vu, kDataPar)]>))))
```
The synthesised 1-extent data-parallel iter var keeps
`iter_values.size() == iter_vars.size()` for downstream checks, and the
new `For` loop gives `ThreadBind` a real loop to bind to `blockIdx.x` /
`threadIdx.x`. Already-scheduled functions and host-only PrimFuncs are
skipped via the existing `IsScheduledOnGPU` / `kIsScheduled` gating.
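A minimal sketch of the three wrap conditions and the resulting rewrite, using toy dataclasses rather than the real TVM node types. The names `Leaf`, `needs_wrap`, and `wrap_scalar_block` are hypothetical, iter vars are modelled as plain strings, and the GPU / `kIsScheduled` gating is left out. The actual pass is implemented in C++ and may differ in detail.

```python
from dataclasses import dataclass, field
from typing import List, Union


@dataclass
class Leaf:
    """Toy leaf statement, e.g. `c[()] = a[()] + b[()]`."""
    text: str


@dataclass
class SBlock:
    name: str
    body: "Stmt"
    iter_vars: List[str] = field(default_factory=list)


@dataclass
class SBlockRealize:
    block: SBlock
    iter_values: List[str] = field(default_factory=list)


@dataclass
class For:
    loop_var: str
    min: int
    extent: int
    body: "Stmt"


Stmt = Union[Leaf, SBlockRealize, For]


def needs_wrap(body: Stmt) -> bool:
    """The three wrap conditions: bare SBlockRealize, no iter vars, leaf body."""
    return (
        isinstance(body, SBlockRealize)
        and not body.block.iter_vars
        and not isinstance(body.block.body, (For, SBlockRealize))
    )


def wrap_scalar_block(body: Stmt) -> Stmt:
    """Nest the scalar block under a new root block and a 1-extent loop."""
    if not needs_wrap(body):
        return body
    # One synthesised iter var bound to the loop var keeps the two lists equal in length.
    inner_block = SBlock(body.block.name, body.block.body, iter_vars=["vu"])
    inner = SBlockRealize(inner_block, iter_values=["u"])
    loop = For("u", 0, 1, inner)
    return SBlockRealize(SBlock("root", loop))


scalar = SBlockRealize(SBlock("scalar_add", Leaf("c[()] = a[()] + b[()]")))
wrapped = wrap_scalar_block(scalar)

# A function whose root block already wraps a loop nest is left untouched.
already_rooted = SBlockRealize(SBlock("root", For("i", 0, 8, Leaf("..."))))
unchanged = wrap_scalar_block(already_rooted)
```

Because the wrapped body's root block now contains a `For`, the third condition no longer holds for it, so applying the sketch a second time is a no-op.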
## Testing
```
pytest tests/python/s_tir/transform/test_s_tir_transform_default_gpu_schedule.py
```
10 passed (9 existing + 1 new `test_scalar_block_no_loops`). End-to-end
compile + execute on RTX 3080 (sm_86): the scalar repro returns the
expected `2.0 + 3.0 = 5.0`.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]