This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a24204640e [TVMScript][Relax] Allow return statement in DataflowBlock
(#17131)
a24204640e is described below
commit a24204640efe3dcf519ca3388633a8a62a7600eb
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Sep 18 13:01:43 2024 -0500
[TVMScript][Relax] Allow return statement in DataflowBlock (#17131)
Prior to this commit, TVMScript required the return value of a Relax
function to be specified outside of any `with R.dataflow()` blocks. This
resulted in a common pattern, where the return value of a function was
first passed to `R.output(ret_value)`, to mark `ret_value` as a
`tvm::relax::Var` instead of a `tvm::relax::DataflowVar`, followed
immediately by a `return ret_value` statement.
This commit updates the TVMScript parser to allow a `return` statement
inside a `with R.dataflow()` block. This is syntactic sugar that is
equivalent to calling `R.output` on every `DataflowVar` used in the
returned expression, followed by a `return`.
With this change, the following two TVMScript examples are now
equivalent. (Prior to this change, the `return_inside_dataflow`
example would raise an error during parsing.)
```python
@R.function(private=True)
def output_then_return(A: R.Tensor):
    with R.dataflow():
        B = R.add(A, A)
        C = R.multiply(B, B)
        R.output(C)
    return C

@R.function(private=True)
def return_inside_dataflow(A: R.Tensor):
    with R.dataflow():
        B = R.add(A, A)
        C = R.multiply(B, B)
        return C
```
---
src/script/ir_builder/relax/frame.cc | 69 ++++++++++++-----------------
src/script/ir_builder/relax/ir.cc | 23 +++++++---
tests/python/relax/test_tvmscript_parser.py | 31 +++++++++++++
3 files changed, 75 insertions(+), 48 deletions(-)
diff --git a/src/script/ir_builder/relax/frame.cc
b/src/script/ir_builder/relax/frame.cc
index 3153c0770e..faf6bd6466 100644
--- a/src/script/ir_builder/relax/frame.cc
+++ b/src/script/ir_builder/relax/frame.cc
@@ -118,36 +118,23 @@ void BlockFrameNode::EnterWithScope() {
}
}
-class DataflowBlockRewriter : public tvm::relax::ExprMutator {
+class VarReplacer : public tvm::relax::ExprMutator {
public:
- static tvm::relax::DataflowBlock Rewrite(const tvm::relax::DataflowBlock&
block,
- const Array<tvm::relax::Var>&
output_vars) {
- DataflowBlockRewriter rewriter(output_vars);
- return
Downcast<tvm::relax::DataflowBlock>(rewriter.VisitBindingBlock(block));
+ explicit VarReplacer(
+ std::unordered_map<tvm::relax::Id, tvm::relax::Var, ObjectPtrHash,
ObjectPtrEqual>
+ var_remap) {
+ var_remap_ = std::move(var_remap);
}
- private:
- explicit DataflowBlockRewriter(const Array<tvm::relax::Var>& output_vars) {
- for (const tvm::relax::Var& var : output_vars) {
- output_var_set_.insert(var.get());
- }
- }
-
- tvm::relax::Var VisitVarDef_(const tvm::relax::DataflowVarNode* op) final {
- auto it = output_var_set_.find(op);
- if (it != output_var_set_.end()) {
- // Rewrite dataflow vars to global vars
- auto n = make_object<tvm::relax::VarNode>(*op);
- tvm::relax::Var new_var(n);
- this->var_remap_[op->vid] = new_var;
- return new_var;
+ tvm::relax::Var VisitVarDef(const tvm::relax::Var& var) override {
+ // ExprMutator only applies var_remap_ at usage sites. This
+ // applies var_remap_ at each definition site as well.
+ if (auto it = var_remap_.find(var->vid); it != var_remap_.end()) {
+ return it->second;
} else {
- return GetRef<tvm::relax::Var>(op);
+ return var;
}
}
-
- private:
- std::unordered_set<const tvm::relax::VarNode*> output_var_set_;
};
void BlockFrameNode::ExitWithScope() {
@@ -164,25 +151,27 @@ void BlockFrameNode::ExitWithScope() {
// Step 3. Rewrite the dataflow block.
if (is_dataflow) {
- // Step 3.1. Rewrite block binding
- block =
DataflowBlockRewriter::Rewrite(Downcast<tvm::relax::DataflowBlock>(block),
output_vars);
-
- // Step 3.2. Collect global vars' reference in bindings
- Map<tvm::relax::Id, tvm::relax::Var> new_global_vars;
- for (const tvm::relax::Binding& binding : block->bindings) {
- if (!binding->var->IsInstance<tvm::relax::DataflowVarNode>()) {
- new_global_vars.Set(binding->var->vid, binding->var);
- }
+ // Step 3.0. Define a map to replace variables
+ Array<tvm::relax::Var> new_output_vars;
+ std::unordered_map<tvm::relax::Id, tvm::relax::Var, ObjectPtrHash,
ObjectPtrEqual> var_remap;
+ for (const auto& output_var : output_vars) {
+ tvm::relax::Var new_output_var(output_var->name_hint(),
GetStructInfo(output_var));
+ new_output_vars.push_back(new_output_var);
+ var_remap[output_var->vid] = new_output_var;
}
+ VarReplacer mutator(std::move(var_remap));
+
+ // Step 3.1. Rewrite block binding
+ block = mutator.VisitBindingBlock(block);
// Step 3.3. Rewrite output vars
- Array<tvm::relax::Var> new_output_vars;
- for (const auto& var : output_vars) {
- auto it = new_global_vars.find(var->vid);
- ICHECK(it != new_global_vars.end());
- new_output_vars.push_back((*it).second);
- }
output_vars = std::move(new_output_vars);
+
+ // Step 3.4 Rewrite usage of output var, if any
+ auto function = FindFunctionFrame("R.dataflow()");
+ if (function->output.defined()) {
+ function->output = mutator.VisitExpr(function->output.value());
+ }
}
// Step 3. Get the last frame from the IRBuilder frame stack.
@@ -196,8 +185,6 @@ void BlockFrameNode::ExitWithScope() {
// Step 5. Push the block frame into the corresponding field of the last
frame.
if (const auto* seq_frame = last_frame.as<SeqExprFrameNode>()) {
- ICHECK(!seq_frame->output.defined())
- << "The function is not expected to have output values when emitting
blocks.";
auto frame = GetRef<SeqExprFrame>(seq_frame);
frame->binding_blocks.push_back(block);
} else {
diff --git a/src/script/ir_builder/relax/ir.cc
b/src/script/ir_builder/relax/ir.cc
index 453c7fdb55..b2e75d0c36 100644
--- a/src/script/ir_builder/relax/ir.cc
+++ b/src/script/ir_builder/relax/ir.cc
@@ -117,20 +117,29 @@ void FuncRetValue(const tvm::relax::Expr& value) {
const tvm::relax::BlockBuilder& block_builder = GetBlockBuilder();
tvm::relax::Expr normalized_value = block_builder->Normalize(value);
+ IRBuilder ir_builder = IRBuilder::Current();
+
// Step 1. The current Relax TVMScript syntax only allows function return
appearing at the end of
// a function body. Therefore if there is any unended block frame when
dealing with function
// return, we should end the block frame.
- Optional<BlockFrame> block_frame =
IRBuilder::Current()->GetLastFrame<BlockFrame>();
- if (block_frame.defined()) {
- block_frame.value()->ExitWithScope();
- ICHECK(!IRBuilder::Current()->FindFrame<BlockFrame>())
- << "ValueError: Relax functions don't support return in true/false
branch of If Node.";
+
+ if (auto opt = ir_builder->GetLastFrame<BlockFrame>()) {
+ auto block_frame = opt.value();
+ for (const auto& var : tvm::relax::FreeVars(normalized_value)) {
+ if (var->IsInstance<tvm::relax::DataflowVarNode>()) {
+ block_frame->output_vars.push_back(var);
+ }
+ }
}
// Step 2. Add the output value to the function frame.
FunctionFrame frame = FindFunctionFrame("return");
CHECK(!frame->output.defined())
- << "ValueError: Relax functions don't support multiple return statement.
Please make sure "
- "the return statement appears at the end of function.";
+ << "ValueError: "
+ << "Relax functions do not support multiple return statement. "
+ << "However, return of " << normalized_value << " occurred after a
return of "
+ << frame->output << ". "
+ << "Please make sure function only has a single return statement, "
+ << "which appears at the end of function.";
frame->output = std::move(normalized_value);
}
diff --git a/tests/python/relax/test_tvmscript_parser.py
b/tests/python/relax/test_tvmscript_parser.py
index fd465f3201..fa62d14848 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -2410,5 +2410,36 @@ def
test_conditional_may_use_symbolic_variables_from_function_scope():
tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo)
+def test_return_from_dataflow_block():
+ """Return statements imply
+
+ The `R.output` statement in a `R.dataflow()` block marks a
+ variable that should be a `relax.Var` instead of a
+ `relax.DataflowVar`, allowing it to be used outside of the
+ `DataflowBlock` that defined it. A relax function's output is not
+ part of any binding, and must not contain any `DataflowVar`, so
+ these are exposed implicitly.
+
+ """
+
+ @R.function(private=True)
+ def output_then_return(A: R.Tensor([16], "float16")):
+ with R.dataflow():
+ B = R.add(A, A)
+ C = R.multiply(B, B)
+ R.output(C)
+
+ return C
+
+ @R.function(private=True)
+ def return_inside_dataflow(A: R.Tensor([16], "float16")):
+ with R.dataflow():
+ B = R.add(A, A)
+ C = R.multiply(B, B)
+ return C
+
+ tvm.ir.assert_structural_equal(output_then_return, return_inside_dataflow)
+
+
if __name__ == "__main__":
tvm.testing.main()