This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3d966230ca Fix InternalError in StaticPlanBlockMemory when visiting DataflowBlockNode (#17501)
3d966230ca is described below

commit 3d966230caa63b4ad8c3d6c86aaad27d5a8a0918
Author: Thrsu <[email protected]>
AuthorDate: Wed Nov 13 13:35:42 2024 +0800

    Fix InternalError in StaticPlanBlockMemory when visiting DataflowBlockNode (#17501)
    
    This PR fixes the internal error reported in #17488.
    
    This error occurs because the visitor class StorageAllocatorBaseVisitor
    does not correctly handle DataflowBlockNode instances.
    Specifically, the VisitBindingBlock_ method is not overridden
    for DataflowBlockNode, which leaves block_stack_ empty
    at a point where it is expected to contain the current block.
    
    To fix this, this PR overrides the VisitBindingBlock_ method
    for const DataflowBlockNode* in the StorageAllocatorBaseVisitor
    class. This ensures that block_stack_ is pushed and popped when
    visiting dataflow blocks, in the same way it is managed for
    regular binding blocks.
---
 src/relax/transform/static_plan_block_memory.cc    |  9 +++++
 .../test_transform_static_plan_block_memory.py     | 41 ++++++++++++++++++++++
 2 files changed, 50 insertions(+)

diff --git a/src/relax/transform/static_plan_block_memory.cc b/src/relax/transform/static_plan_block_memory.cc
index 74200526b6..44e338cbe8 100644
--- a/src/relax/transform/static_plan_block_memory.cc
+++ b/src/relax/transform/static_plan_block_memory.cc
@@ -314,6 +314,15 @@ class StorageAllocatorBaseVisitor : public ExprVisitor {
     SetTokens(binding->var.get(), token_map_[binding->value.get()]);
   }
 
+  void VisitBindingBlock_(const DataflowBlockNode* block) override {  // fixes #17488
+    // Track the enclosing block (as for plain binding blocks) for the token alloc/use-site check.
+    block_stack_.push_back(block);
+    ExprVisitor::VisitBindingBlock_(block);
+    ICHECK(!block_stack_.empty());  // nested visits must leave this block on the stack
+    ICHECK(block_stack_.back() == block);
+    block_stack_.pop_back();
+  }
+
   void VisitExpr_(const TupleNode* tuple) final {
     Array<Tokens> tokens;
     tokens.reserve(tuple->fields.size());
diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py b/tests/python/relax/test_transform_static_plan_block_memory.py
index 1150827b19..28015f0eec 100644
--- a/tests/python/relax/test_transform_static_plan_block_memory.py
+++ b/tests/python/relax/test_transform_static_plan_block_memory.py
@@ -1504,5 +1504,46 @@ def test_view():
     tvm.ir.assert_structural_equal(after, Expected)
 
 
+def test_with_dataflow():  # regression: planning must not crash on dataflow blocks (#17488)
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def exp(A: T.handle, B: T.handle):
+            T.evaluate(0)
+
+        @R.function
+        def main(x: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"):
+            cls = Before
+            with R.dataflow():
+                alloc: R.Tensor((10,), dtype="float32") = R.builtin.alloc_tensor(
+                    R.shape([10]), R.dtype("float32"), runtime_device_index=0
+                )
+                _: R.Tuple() = cls.exp(x, alloc)
+                gv: R.Tensor((10,), dtype="float32") = alloc
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def exp(A: T.handle, B: T.handle):
+            T.evaluate(0)
+
+        @R.function
+        def main(x: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"):
+            cls = Expected
+            with R.dataflow():
+                alloc: R.Tensor((10,), dtype="float32") = R.builtin.alloc_tensor(
+                    R.shape([10]), R.dtype("float32"), R.prim_value(0), R.str("global")
+                )
+                cls.exp(x, alloc)
+                gv: R.Tensor((10,), dtype="float32") = alloc
+                R.output(gv)
+            return gv
+
+    after = relax.transform.StaticPlanBlockMemory()(Before)
+    tvm.ir.assert_structural_equal(after, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to