This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 6118b770b1 [Unity] Improved error checking for DataflowBlock in nested SeqExpr (#16195)
6118b770b1 is described below

commit 6118b770b1dada9630631286c76cd7501374446e
Author: Eric Lunderberg <[email protected]>
AuthorDate: Thu Dec 14 14:00:30 2023 -0600

    [Unity] Improved error checking for DataflowBlock in nested SeqExpr (#16195)
    
    * [Unity] Improved error checking for DataflowBlock in nested SeqExpr
    
    Prior to this commit, the normalizer printed a warning whenever a
    lowering pass produced a `SeqExpr` containing a non-dataflow
    `BindingBlock` within a `DataflowBlock`.  This intermediate is only
    ill-formed if the non-dataflow `BindingBlock` makes use of any
    `DataflowVar` instances, as the `BindingBlock` can otherwise be
    hoisted out.
    
    This commit explicitly checks for use of `DataflowVar` in these cases,
    as this is most likely due to erroneous use of `BindingBlock` where
    `DataflowBlock` should be used, and changes the severity from
    `WARNING` to `FATAL`.  If there are no `DataflowVar` instances used in
    the `BindingBlock`, no warning is required.
    
    * Added unit test for error being caught.
---
 src/relax/ir/block_builder.cc                |  32 ++++-
 tests/python/relax/test_blockbuilder_core.py | 192 +++++++++++++++++++++++++++
 2 files changed, 221 insertions(+), 3 deletions(-)

diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc
index fda31e44a9..b445bde6f5 100644
--- a/src/relax/ir/block_builder.cc
+++ b/src/relax/ir/block_builder.cc
@@ -865,9 +865,35 @@ class Normalizer : public BlockBuilderImpl, private 
ExprFunctor<Expr(const Expr&
           // and thus flattened the inner SeqExprs already
           for (const BindingBlock& block : seq->blocks) {
             if (is_dataflow && !block->IsInstance<DataflowBlockNode>()) {
-              LOG(WARNING) << "Malformed AST: Seq expr nested inside a 
dataflow block contains a "
-                              "non-dataflow block! "
-                           << seq;
+              // A non-DataflowBlock occurring within a DataflowBlock
+              // usually is an error, caused by using a `BindingBlock`
+              // where a `DataflowBlock` was intended.  However, it may
+              // still be well-formed if the non-DataflowBlock uses no
+              // relax::DataflowVar instances.  This would result in
+              // multiple dataflow sections, split by non-dataflow
+              // portions, but would still be valid.
+              //
+              // Since the most common occurrence is due to mis-use,
+              // explicitly check for free DataflowVars here rather than
+              // waiting for a WellFormed check later on.
+
+              auto free_vars = FreeVars(SeqExpr({block}, 
Tuple(Array<Expr>{})));
+              Array<DataflowVar> free_dataflow_vars;
+              for (const auto& var : free_vars) {
+                if (auto opt = var.as<DataflowVar>()) {
+                  free_dataflow_vars.push_back(opt.value());
+                }
+              }
+
+              if (free_dataflow_vars.size()) {
+                LOG(FATAL)
+                    << "Malformed AST: "
+                    << "A DataflowVar may only be used within a DataflowBlock. 
 "
+                    << "The variable " << binding->var << " is defined within 
a DataflowBlock, "
+                    << "but is bound to a SeqExpr that contains non-dataflow 
BindingBlocks.  "
+                    << "These non-dataflow BindingBlocks use the DataflowVars "
+                    << free_dataflow_vars << ", which is invalid.";
+              }
             }
             ret.push_back(block);
           }
diff --git a/tests/python/relax/test_blockbuilder_core.py 
b/tests/python/relax/test_blockbuilder_core.py
index e1cbe37b18..255ef08560 100644
--- a/tests/python/relax/test_blockbuilder_core.py
+++ b/tests/python/relax/test_blockbuilder_core.py
@@ -728,5 +728,197 @@ def test_finalize_public_private_name_conflict():
     assert rx.analysis.well_formed(mod)
 
 
+def test_emit_nested_seqexpr_in_binding_block():
+    """May emit a SeqExpr inside a BindingBlock, flattening its bindings"""
+
+    bb = rx.BlockBuilder()
+
+    with bb.function("func", []):
+        lhs = bb.emit(rx.const(1, "int64"), "a")
+        rhs = bb.emit(rx.const(2, "int64"), "b")
+        out = bb.emit(rx.op.add(lhs, rhs), "c")
+        bb.emit_func_output(out)
+
+    seq_expr = bb.finalize()["func"].body
+
+    bb = rx.BlockBuilder()
+    with bb.function("func", [], private=True):
+        lhs = bb.emit(rx.const(3, "int64"), "d")
+        rhs = bb.emit(seq_expr, "e")
+        out = bb.emit(rx.op.add(lhs, rhs), "f")
+        bb.emit_func_output(out)
+
+    output = bb.finalize()["func"]
+
+    @R.function(private=True)
+    def expected():
+        d = R.const(3, "int64")
+        a = R.const(1, "int64")
+        b = R.const(2, "int64")
+        c = R.add(a, b)
+        e = c
+        f = R.add(d, e)
+        return f
+
+    tvm.ir.assert_structural_equal(expected, output)
+
+
+def test_emit_nested_dataflow_seqexpr_in_dataflow_block():
+    """May emit a dataflow SeqExpr inside a DataflowBlock, flattening it"""
+    bb = rx.BlockBuilder()
+
+    with bb.function("func", []):
+        with bb.dataflow():
+            lhs = bb.emit(rx.const(1, "int64"), "a")
+            rhs = bb.emit(rx.const(2, "int64"), "b")
+            out = bb.emit_output(rx.op.add(lhs, rhs), "c")
+        bb.emit_func_output(out)
+
+    seq_expr = bb.finalize()["func"].body
+
+    bb = rx.BlockBuilder()
+    with bb.function("func", [], private=True):
+        with bb.dataflow():
+            lhs = bb.emit(rx.const(3, "int64"), "d")
+            rhs = bb.emit(seq_expr, "e")
+            out = bb.emit_output(rx.op.add(lhs, rhs), "f")
+        bb.emit_func_output(out)
+
+    output = bb.finalize()["func"]
+
+    @R.function(private=True)
+    def expected():
+        with R.dataflow():
+            d = R.const(3, "int64")
+            a = R.const(1, "int64")
+            b = R.const(2, "int64")
+            c = R.add(a, b)
+            e = c
+            f = R.add(d, e)
+            R.output(c, f)
+        return f
+
+    tvm.ir.assert_structural_equal(expected, output)
+
+
+def test_emit_ill_formed_nested_seqexpr_in_dataflow_block():
+    """May emit a non-dataflow SeqExpr inside a DataflowBlock
+
+    This produces ill-formed code, but cannot be caught by the
+    normalizer.  See also
+    test_emit_well_formed_nested_seqexpr_in_dataflow_block.
+
+    """
+    bb = rx.BlockBuilder()
+
+    with bb.function("func", []):
+        lhs = bb.emit(rx.const(1, "int64"), "a")
+        rhs = bb.emit(rx.const(2, "int64"), "b")
+        out = bb.emit(rx.op.add(lhs, rhs), "c")
+        bb.emit_func_output(out)
+
+    seq_expr = bb.finalize()["func"].body
+
+    bb = rx.BlockBuilder()
+    with bb.function("func", [], private=True):
+        with bb.dataflow():
+            lhs = bb.emit(rx.const(3, "int64"), "d")
+            # This requires breaking up the DataflowBlock with a
+            # BindingBlock, which is not by itself ill-formed.
+            rhs = bb.emit(seq_expr, "e")
+
+            # We cannot throw an error at that point, because it is
+            # only the later usage of "d" that results in use of a
+            # DataflowVar outside of its home DataflowBlock.
+            out = bb.emit_output(rx.op.add(lhs, rhs), "f")
+        bb.emit_func_output(out)
+
+    output = bb.finalize()["func"]
+
+    assert not rx.analysis.well_formed(tvm.ir.IRModule.from_expr(output))
+
+
+def test_emit_well_formed_nested_seqexpr_in_dataflow_block():
+    """May emit a non-dataflow SeqExpr inside a DataflowBlock
+
+    This produces well-formed code, and the normalizer should not
+    emit any warning or error for it.  See also
+    test_emit_ill_formed_nested_seqexpr_in_dataflow_block.
+    """
+    bb = rx.BlockBuilder()
+
+    with bb.function("func", []):
+        lhs = bb.emit(rx.const(1, "int64"), "a")
+        rhs = bb.emit(rx.const(2, "int64"), "b")
+        out = bb.emit(rx.op.add(lhs, rhs), "c")
+        bb.emit_func_output(out)
+
+    seq_expr = bb.finalize()["func"].body
+
+    bb = rx.BlockBuilder()
+    with bb.function("func", [], private=True):
+        with bb.dataflow():
+            lhs = bb.emit(rx.const(3, "int64"), "d")
+            # This similarly breaks up the DataflowBlock, with
+            # identical steps as the previous test up until this
+            # point.
+            rhs = bb.emit(seq_expr, "e")
+
+            # But the "d" variable isn't used, and so there aren't any
+            # usages of DataflowVar outside of their home
+            # DataflowBlock.
+            out = bb.emit_output(rhs, "f")
+        bb.emit_func_output(out)
+
+    output = bb.finalize()["func"]
+
+    assert rx.analysis.well_formed(tvm.ir.IRModule.from_expr(output))
+
+    @R.function(private=True)
+    def expected() -> R.Tensor((), dtype="int64"):
+        with R.dataflow():
+            d = R.const(3, "int64")
+            R.output()
+        a = R.const(1, "int64")
+        b = R.const(2, "int64")
+        c = R.add(a, b)
+        with R.dataflow():
+            e = c
+            f = e
+            R.output(f)
+        return f
+
+    tvm.ir.assert_structural_equal(expected, output)
+
+
+def test_error_when_unwrapping_dataflowvar():
+    """Checks for ill-formed use of DataflowVar at normalization
+
+    Some illegal unwrapping of SeqExpr can be caught early, though.
+    If the inlined non-dataflow SeqExpr uses a DataflowVar, the error
+    is raised when it is unwrapped (here, at `emit_func_output`).
+    """
+    bb = rx.BlockBuilder()
+
+    lhs = rx.Var("a", rx.TensorStructInfo(shape=[], dtype="int64"))
+
+    with bb.function("func", [lhs]):
+        rhs = rx.const(2, "int64")
+        out = bb.emit(rx.op.add(lhs, rhs))
+        bb.emit_func_output(out)
+
+    func = bb.finalize()["func"]
+
+    bb = rx.BlockBuilder()
+    with bb.function("func", [], private=True):
+        with bb.dataflow():
+            local_lhs = bb.emit(rx.const(3, "int64"), "local_a")
+            rhs = bb.emit(func.bind_params({lhs: local_lhs}).body, "f")
+            out = bb.emit_output(rhs, "f")
+
+        with pytest.raises(tvm.TVMError, match="Malformed AST"):
+            bb.emit_func_output(out)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to