nautasolva opened a new issue, #16889:
URL: https://github.com/apache/tvm/issues/16889
When used on a block with an init statement, `blockize` creates a separate init
block that cannot be retrieved through the schedule API. This hinders further
scheduling, such as tensorizing the init block.
### Expected behavior
When using `blockize` on a loop that contains an init statement, the init is
moved to a new block `<block>_init`, which should be discoverable with
`get_block`, or with `get_child_blocks` on the newly created outer block.
### Actual behavior
The init block exists in the TIR module but does not seem to be registered by
the schedule. `get_block("<block>_init")` fails with `InternalError: Check
failed: (it != self_->stmt2ref.end()) is false`.
<details><summary>Stacktrace</summary>
<p>
Traceback (most recent call last):
File "/home/dev/tvm_upstream/../tvm/playground/blockize_init_bug.py", line
31, in <module>
a_init = sch.get_block("A_init")
File "/home/dev/tvm_upstream/python/tvm/tir/schedule/_type_checker.py",
line 340, in wrap
return func(*args, **kwargs)
File "/home/dev/tvm_upstream/python/tvm/tir/schedule/schedule.py", line
499, in get_block
return _ffi_api.ScheduleGetBlock( # type: ignore # pylint:
disable=no-member
File "/home/dev/tvm_upstream/python/tvm/_ffi/_ctypes/packed_func.py", line
239, in __call__
raise_last_ffi_error()
File "/home/dev/tvm_upstream/python/tvm/_ffi/base.py", line 481, in
raise_last_ffi_error
raise py_err
File "/home/dev/tvm_upstream/src/tir/schedule/traced_schedule.cc", line
128, in tvm::tir::TracedScheduleNode::GetBlock(tvm::runtime::String const&,
tvm::runtime::Optional<tvm::runtime::String> const&)
BlockRV result = ConcreteScheduleNode::GetBlock(name, func_name);
File "/home/dev/tvm_upstream/src/tir/schedule/concrete_schedule.cc", line
321, in tvm::tir::ConcreteScheduleNode::GetBlock(tvm::runtime::String const&,
tvm::runtime::Optional<tvm::runtime::String> const&)
Array<StmtSRef> blocks = tir::GetBlocks(this->state_, name, gv);
File
"/home/dev/tvm_upstream/src/tir/schedule/primitive/get_block_loop.cc", line 46,
in tvm::tir::GetBlocks(tvm::tir::ScheduleState const&, tvm::runtime::String
const&, tvm::GlobalVar const&)
finder(prim_func->body);
File "/home/dev/tvm_upstream/src/tir/ir/stmt_functor.cc", line 142, in
tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::BlockNode const*)
this->VisitStmt(op->init.value());
File
"/home/dev/tvm_upstream/src/tir/schedule/primitive/get_block_loop.cc", line 29,
in VisitStmt_
void VisitStmt_(const BlockNode* block) override {
File
"/home/dev/tvm_upstream/src/tir/schedule/primitive/get_block_loop.cc", line 32,
in VisitStmt_
ICHECK(it != self_->stmt2ref.end());
tvm.error.InternalError: Traceback (most recent call last):
5: tvm::tir::TracedScheduleNode::GetBlock(tvm::runtime::String const&,
tvm::runtime::Optional<tvm::runtime::String> const&)
at /home/dev/tvm_upstream/src/tir/schedule/traced_schedule.cc:128
4: tvm::tir::ConcreteScheduleNode::GetBlock(tvm::runtime::String const&,
tvm::runtime::Optional<tvm::runtime::String> const&)
at /home/dev/tvm_upstream/src/tir/schedule/concrete_schedule.cc:321
3: tvm::tir::GetBlocks(tvm::tir::ScheduleState const&,
tvm::runtime::String const&, tvm::GlobalVar const&)
at
/home/dev/tvm_upstream/src/tir/schedule/primitive/get_block_loop.cc:46
2: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::BlockNode const*)
at /home/dev/tvm_upstream/src/tir/ir/stmt_functor.cc:142
1: VisitStmt_
at
/home/dev/tvm_upstream/src/tir/schedule/primitive/get_block_loop.cc:29
0: VisitStmt_
at
/home/dev/tvm_upstream/src/tir/schedule/primitive/get_block_loop.cc:32
File
"/home/dev/tvm_upstream/src/tir/schedule/primitive/get_block_loop.cc", line 32
</p>
</details>
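Reading the traceback, the block finder appears to walk the whole function body, including each block's `init` statement, and to look up every matching block in `stmt2ref`. A block that is present in the IR but missing from that map would then trip the check. The snippet below is a toy pure-Python model of that failure mode, not TVM code: `Block`, `ToyState`, and `get_blocks` are hypothetical names, the outer block name `A_o` is only illustrative, and the idea that `blockize` leaves the init block unregistered is an assumption inferred from the error message.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class Block:
    """Toy stand-in for a TIR block: a name, an optional init, and child blocks."""
    name: str
    init: Optional["Block"] = None
    body: List["Block"] = field(default_factory=list)


class ToyState:
    """Toy stand-in for the schedule state: maps registered blocks to handles."""

    def __init__(self) -> None:
        self.stmt2ref: Dict[int, str] = {}

    def register(self, block: Block) -> None:
        # Key by object identity, loosely imitating a pointer-keyed map.
        self.stmt2ref[id(block)] = f"sref({block.name})"


def get_blocks(state: ToyState, root: Block, name: str) -> List[str]:
    """Walk the tree (init included) and return handles of blocks named `name`."""
    found: List[str] = []

    def visit(block: Block) -> None:
        if block.name == name:
            # Imitates the failing check: a matching block must be registered.
            if id(block) not in state.stmt2ref:
                raise RuntimeError(f"block {block.name!r} is not in stmt2ref")
            found.append(state.stmt2ref[id(block)])
        if block.init is not None:
            visit(block.init)
        for child in block.body:
            visit(child)

    visit(root)
    return found


# Hypothetical tree after blockize: outer block with an init block and the inner block.
a_init = Block("A_init")
a = Block("A")
a_o = Block("A_o", init=a_init, body=[a])
root = Block("root", body=[a_o])

state = ToyState()
for blk in (root, a_o, a):  # "A_init" is deliberately left unregistered
    state.register(blk)

print(get_blocks(state, root, "A"))
try:
    get_blocks(state, root, "A_init")
except RuntimeError as err:
    print("lookup failed:", err)
```

In this model the lookup succeeds as soon as `state.register(a_init)` is called, which is the behavior this issue asks for from the real schedule state.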
### Environment
Reproducible on main (d4056ca79571d4265a12beeedd1b1565953df936)
### Steps to reproduce
```python
import tvm
from tvm.script import ir as I
from tvm.script import tir as T


@I.ir_module
class Module:
    @T.prim_func
    def main():
        # with T.block("root"):
        A_sum = T.alloc_buffer((1,), "float32")
        A = T.alloc_buffer((1, 16), "float32")
        for nn, ff in T.grid(1, 16):
            with T.block("A"):
                v_nn, v_ff = T.axis.remap("SR", [nn, ff])
                T.reads(A[v_nn, v_ff])
                T.writes(A_sum[v_nn])
                with T.init():
                    A_sum[v_nn] = T.float32(0)
                A_sum[v_nn] = A_sum[v_nn] + A[v_nn, v_ff]


sch = tvm.tir.Schedule(Module)
a = sch.get_block("A")
loop_n, loop_f = sch.get_loops(a)
sch.blockize(loop_f)
print(sch.mod)  # <-- A_init exists
# Fails with InternalError: Check failed: (it != self_->stmt2ref.end()) is false
a_init = sch.get_block("A_init")
```
### Triage
* tir:schedule