This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 6118b770b1 [Unity] Improved error checking for DataflowBlock in nested SeqExpr (#16195)
6118b770b1 is described below
commit 6118b770b1dada9630631286c76cd7501374446e
Author: Eric Lunderberg <[email protected]>
AuthorDate: Thu Dec 14 14:00:30 2023 -0600
[Unity] Improved error checking for DataflowBlock in nested SeqExpr (#16195)
* [Unity] Improved error checking for DataflowBlock in nested SeqExpr
Prior to this commit, the normalizer printed a warning whenever a
lowering pass produced a `SeqExpr` containing a non-dataflow
`BindingBlock` within a `DataflowBlock`. This intermediate form is
only ill-formed if the non-dataflow `BindingBlock` makes use of any
`DataflowVar` instances, as the `BindingBlock` can otherwise be
hoisted out.
This commit explicitly checks for use of `DataflowVar` in these cases,
as this is most likely due to erroneous use of `BindingBlock` where
`DataflowBlock` should be used, and changes the severity from
`WARNING` to `FATAL`. If no `DataflowVar` instances are used in the
`BindingBlock`, no warning or error is emitted.
* Added unit test for error being caught.
---
src/relax/ir/block_builder.cc | 32 ++++-
tests/python/relax/test_blockbuilder_core.py | 192 +++++++++++++++++++++++++++
2 files changed, 221 insertions(+), 3 deletions(-)
diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc
index fda31e44a9..b445bde6f5 100644
--- a/src/relax/ir/block_builder.cc
+++ b/src/relax/ir/block_builder.cc
@@ -865,9 +865,35 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor<Expr(const Expr&
// and thus flattened the inner SeqExprs already
for (const BindingBlock& block : seq->blocks) {
if (is_dataflow && !block->IsInstance<DataflowBlockNode>()) {
- LOG(WARNING) << "Malformed AST: Seq expr nested inside a dataflow block contains a "
- "non-dataflow block! "
- << seq;
+ // A non-dataflow BindingBlock occurring within a DataflowBlock
+ // is usually an error, resulting from a `BindingBlock` being
+ // produced where a `DataflowBlock` was intended. However, it
+ // may still be well-formed if the non-dataflow BindingBlock
+ // uses no relax::DataflowVar instances, producing multiple
+ // dataflow sections split by non-dataflow portions.
+ //
+ // Since the most common occurrence is due to misuse,
+ // explicitly check for it here rather than waiting for a
+ // later WellFormed check. The block is wrapped in a SeqExpr
+ // so that FreeVars can be computed over this block alone.
+
+ auto free_vars = FreeVars(SeqExpr({block}, Tuple(Array<Expr>{})));
+ Array<DataflowVar> free_dataflow_vars;
+ for (const auto& var : free_vars) {
+ if (auto opt = var.as<DataflowVar>()) {
+ free_dataflow_vars.push_back(opt.value());
+ }
+ }
+
+ if (!free_dataflow_vars.empty()) {
+ LOG(FATAL)
+ << "Malformed AST: "
+ << "A DataflowVar may only be used within a DataflowBlock. "
+ << "The variable " << binding->var << " is defined within a DataflowBlock, "
+ << "but is bound to a SeqExpr that contains non-dataflow BindingBlocks. "
+ << "These non-dataflow BindingBlocks use the DataflowVars "
+ << free_dataflow_vars << ", which is invalid.";
+ }
}
ret.push_back(block);
}
diff --git a/tests/python/relax/test_blockbuilder_core.py b/tests/python/relax/test_blockbuilder_core.py
index e1cbe37b18..255ef08560 100644
--- a/tests/python/relax/test_blockbuilder_core.py
+++ b/tests/python/relax/test_blockbuilder_core.py
@@ -728,5 +728,197 @@ def test_finalize_public_private_name_conflict():
assert rx.analysis.well_formed(mod)
+def test_emit_nested_seqexpr_in_binding_block():
+ """May emit a SeqExpr inside a BindingBlock, which is then flattened"""
+
+ bb = rx.BlockBuilder()
+
+ with bb.function("func", []):
+ lhs = bb.emit(rx.const(1, "int64"), "a")
+ rhs = bb.emit(rx.const(2, "int64"), "b")
+ out = bb.emit(rx.op.add(lhs, rhs), "c")
+ bb.emit_func_output(out)
+
+ seq_expr = bb.finalize()["func"].body
+ # Emit that body, a SeqExpr with a non-dataflow block, as "e".
+ bb = rx.BlockBuilder()
+ with bb.function("func", [], private=True):
+ lhs = bb.emit(rx.const(3, "int64"), "d")
+ rhs = bb.emit(seq_expr, "e")
+ out = bb.emit(rx.op.add(lhs, rhs), "f")
+ bb.emit_func_output(out)
+
+ output = bb.finalize()["func"]
+ # The nested bindings (a, b, c) are inlined ahead of "e".
+ @R.function(private=True)
+ def expected():
+ d = R.const(3, "int64")
+ a = R.const(1, "int64")
+ b = R.const(2, "int64")
+ c = R.add(a, b)
+ e = c
+ f = R.add(d, e)
+ return f
+
+ tvm.ir.assert_structural_equal(expected, output)
+
+
+def test_emit_nested_dataflow_seqexpr_in_dataflow_block():
+ """May emit a dataflow SeqExpr inside a DataflowBlock, which is then flattened"""
+ bb = rx.BlockBuilder()
+
+ with bb.function("func", []):
+ with bb.dataflow():
+ lhs = bb.emit(rx.const(1, "int64"), "a")
+ rhs = bb.emit(rx.const(2, "int64"), "b")
+ out = bb.emit_output(rx.op.add(lhs, rhs), "c")
+ bb.emit_func_output(out)
+
+ seq_expr = bb.finalize()["func"].body
+ # Emit that body, which holds a single DataflowBlock, as "e".
+ bb = rx.BlockBuilder()
+ with bb.function("func", [], private=True):
+ with bb.dataflow():
+ lhs = bb.emit(rx.const(3, "int64"), "d")
+ rhs = bb.emit(seq_expr, "e")
+ out = bb.emit_output(rx.op.add(lhs, rhs), "f")
+ bb.emit_func_output(out)
+
+ output = bb.finalize()["func"]
+ # The nested bindings are inlined into the outer DataflowBlock.
+ @R.function(private=True)
+ def expected():
+ with R.dataflow():
+ d = R.const(3, "int64")
+ a = R.const(1, "int64")
+ b = R.const(2, "int64")
+ c = R.add(a, b)
+ e = c
+ f = R.add(d, e)
+ R.output(c, f)
+ return f
+
+ tvm.ir.assert_structural_equal(expected, output)
+
+
+def test_emit_ill_formed_nested_seqexpr_in_dataflow_block():
+ """May emit a non-dataflow SeqExpr inside a DataflowBlock
+
+ This produces ill-formed code, but cannot be caught by the
+ normalizer. See also
+ test_emit_well_formed_nested_seqexpr_in_dataflow_block.
+
+ """
+ bb = rx.BlockBuilder()
+
+ with bb.function("func", []):
+ lhs = bb.emit(rx.const(1, "int64"), "a")
+ rhs = bb.emit(rx.const(2, "int64"), "b")
+ out = bb.emit(rx.op.add(lhs, rhs), "c")
+ bb.emit_func_output(out)
+
+ seq_expr = bb.finalize()["func"].body
+
+ bb = rx.BlockBuilder()
+ with bb.function("func", [], private=True):
+ with bb.dataflow():
+ lhs = bb.emit(rx.const(3, "int64"), "d")
+ # This splits the DataflowBlock around a non-dataflow
+ # BindingBlock, which on its own is still well-formed.
+ rhs = bb.emit(seq_expr, "e")
+
+ # We cannot throw an error at that point, because it is
+ # only the later usage of "d" that results in use of a
+ # DataflowVar outside of its home DataflowBlock.
+ out = bb.emit_output(rx.op.add(lhs, rhs), "f")
+ bb.emit_func_output(out)
+
+ output = bb.finalize()["func"]
+
+ assert not rx.analysis.well_formed(tvm.ir.IRModule.from_expr(output))
+
+
+def test_emit_well_formed_nested_seqexpr_in_dataflow_block():
+ """May emit a non-dataflow SeqExpr inside a DataflowBlock
+
+ This produces well-formed code, so the normalizer should not
+ raise any error or print any warning. See also
+ test_emit_ill_formed_nested_seqexpr_in_dataflow_block.
+ """
+ bb = rx.BlockBuilder()
+
+ with bb.function("func", []):
+ lhs = bb.emit(rx.const(1, "int64"), "a")
+ rhs = bb.emit(rx.const(2, "int64"), "b")
+ out = bb.emit(rx.op.add(lhs, rhs), "c")
+ bb.emit_func_output(out)
+
+ seq_expr = bb.finalize()["func"].body
+
+ bb = rx.BlockBuilder()
+ with bb.function("func", [], private=True):
+ with bb.dataflow():
+ lhs = bb.emit(rx.const(3, "int64"), "d")
+ # This similarly breaks up the DataflowBlock, with
+ # identical steps as the previous test up until this
+ # point.
+ rhs = bb.emit(seq_expr, "e")
+
+ # But the "d" variable isn't used, and so there aren't any
+ # usages of DataflowVar outside of their home
+ # DataflowBlock.
+ out = bb.emit_output(rhs, "f")
+ bb.emit_func_output(out)
+
+ output = bb.finalize()["func"]
+
+ assert rx.analysis.well_formed(tvm.ir.IRModule.from_expr(output))
+ # The nested BindingBlock splits the dataflow section in two.
+ @R.function(private=True)
+ def expected() -> R.Tensor((), dtype="int64"):
+ with R.dataflow():
+ d = R.const(3, "int64")
+ R.output()
+ a = R.const(1, "int64")
+ b = R.const(2, "int64")
+ c = R.add(a, b)
+ with R.dataflow():
+ e = c
+ f = e
+ R.output(f)
+ return f
+
+ tvm.ir.assert_structural_equal(expected, output)
+
+
+def test_error_when_unwrapping_dataflowvar():
+ """Checks for ill-formed use of DataflowVar at normalization
+
+ Some illegal unwrapping of a SeqExpr can be caught early. If
+ the inlined non-dataflow SeqExpr uses a DataflowVar, that should
+ trigger an error when the SeqExpr is being unwrapped.
+ """
+ bb = rx.BlockBuilder()
+
+ lhs = rx.Var("a", rx.TensorStructInfo(shape=[], dtype="int64"))
+
+ with bb.function("func", [lhs]):
+ rhs = rx.const(2, "int64")
+ out = bb.emit(rx.op.add(lhs, rhs))
+ bb.emit_func_output(out)
+
+ func = bb.finalize()["func"]
+ # The body of `func` uses `lhs`, which is rebound below to a DataflowVar.
+ bb = rx.BlockBuilder()
+ with bb.function("func", [], private=True):
+ with bb.dataflow():
+ local_lhs = bb.emit(rx.const(3, "int64"), "local_a")
+ rhs = bb.emit(func.bind_params({lhs: local_lhs}).body, "e")
+ out = bb.emit_output(rhs, "f")
+ # The error is expected once the function body is emitted.
+ with pytest.raises(tvm.TVMError, match="Malformed AST"):
+ bb.emit_func_output(out)
+
+
if __name__ == "__main__":
tvm.testing.main()