This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3d966230ca Fix InternalError in StaticPlanBlockMemory when visiting DataflowBlockNode (#17501)
3d966230ca is described below
commit 3d966230caa63b4ad8c3d6c86aaad27d5a8a0918
Author: Thrsu <[email protected]>
AuthorDate: Wed Nov 13 13:35:42 2024 +0800
Fix InternalError in StaticPlanBlockMemory when visiting DataflowBlockNode (#17501)
This PR fixes the InternalError reported in issue #17488.
This error happens because the visitor class StorageAllocatorBaseVisitor
does not correctly handle DataflowBlockNode instances.
Specifically, the VisitBindingBlock_ method is not overridden
for DataflowBlockNode, leading to an empty block_stack_
when it is expected to contain the current block.
To fix this issue, this PR overrides the VisitBindingBlock_
method for const DataflowBlockNode* in the
StorageAllocatorBaseVisitor class. This ensures that
block_stack_ is managed when visiting dataflow blocks
in the same way as it is for regular binding blocks.
---
src/relax/transform/static_plan_block_memory.cc | 9 +++++
.../test_transform_static_plan_block_memory.py | 41 ++++++++++++++++++++++
2 files changed, 50 insertions(+)
diff --git a/src/relax/transform/static_plan_block_memory.cc b/src/relax/transform/static_plan_block_memory.cc
index 74200526b6..44e338cbe8 100644
--- a/src/relax/transform/static_plan_block_memory.cc
+++ b/src/relax/transform/static_plan_block_memory.cc
@@ -314,6 +314,15 @@ class StorageAllocatorBaseVisitor : public ExprVisitor {
SetTokens(binding->var.get(), token_map_[binding->value.get()]);
}
+  void VisitBindingBlock_(const DataflowBlockNode* block) override {
+    // As for BindingBlockNode: track the block for token allocation-site/use-site checks.
+    block_stack_.push_back(block);
+    ExprVisitor::VisitBindingBlock_(block);
+    ICHECK(!block_stack_.empty());
+    ICHECK(block_stack_.back() == block);
+    block_stack_.pop_back();
+  }
+
void VisitExpr_(const TupleNode* tuple) final {
Array<Tokens> tokens;
tokens.reserve(tuple->fields.size());
diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py b/tests/python/relax/test_transform_static_plan_block_memory.py
index 1150827b19..28015f0eec 100644
--- a/tests/python/relax/test_transform_static_plan_block_memory.py
+++ b/tests/python/relax/test_transform_static_plan_block_memory.py
@@ -1504,5 +1504,46 @@ def test_view():
tvm.ir.assert_structural_equal(after, Expected)
+def test_with_dataflow():
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def exp(A: T.handle, B: T.handle):
+            T.evaluate(0)
+
+        @R.function
+        def main(x: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"):
+            cls = Before
+            with R.dataflow():  # regression test for #17488
+                alloc: R.Tensor((10,), dtype="float32") = R.builtin.alloc_tensor(
+                    R.shape([10]), R.dtype("float32"), runtime_device_index=0
+                )
+                _: R.Tuple() = cls.exp(x, alloc)
+                gv: R.Tensor((10,), dtype="float32") = alloc
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def exp(A: T.handle, B: T.handle):
+            T.evaluate(0)
+
+        @R.function
+        def main(x: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"):
+            cls = Expected
+            with R.dataflow():
+                alloc: R.Tensor((10,), dtype="float32") = R.builtin.alloc_tensor(
+                    R.shape([10]), R.dtype("float32"), R.prim_value(0), R.str("global")
+                )
+                cls.exp(x, alloc)
+                gv: R.Tensor((10,), dtype="float32") = alloc
+                R.output(gv)
+            return gv
+
+    after = relax.transform.StaticPlanBlockMemory()(Before)
+    tvm.ir.assert_structural_equal(after, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()