This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new a24204640e [TVMScript][Relax] Allow return statement in DataflowBlock (#17131)
a24204640e is described below

commit a24204640efe3dcf519ca3388633a8a62a7600eb
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Sep 18 13:01:43 2024 -0500

    [TVMScript][Relax] Allow return statement in DataflowBlock (#17131)
    
    Prior to this commit, TVMScript required the return value of a Relax
    function to be specified outside of any `with R.dataflow()` blocks.
    This resulted in a common pattern, where the return value of a
    function was first passed to `R.output(ret_value)`, marking
    `ret_value` as a `tvm::relax::Var` instead of a
    `tvm::relax::DataflowVar`, followed immediately by a
    `return ret_value` statement.
    
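    This commit updates the TVMScript parser to allow a `return` statement
    inside a `with R.dataflow()` block.  This is syntactic sugar: every
    `DataflowVar` used by the returned expression is implicitly marked as
    an output of the block, exactly as if `R.output` had been called on
    it, and the `return` then applies after the block ends.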
    
    With this change, the following two TVMScript examples are now
    equivalent.  (Prior to this change, the `return_inside_dataflow`
    example would raise an error during parsing.)
    
    ```python
    @R.function(private=True)
    def output_then_return(A: R.Tensor):
        with R.dataflow():
            B = R.add(A, A)
            C = R.multiply(B, B)
            R.output(C)
    
        return C
    
    @R.function(private=True)
    def return_inside_dataflow(A: R.Tensor):
        with R.dataflow():
            B = R.add(A, A)
            C = R.multiply(B, B)
            return C
    ```
---
 src/script/ir_builder/relax/frame.cc        | 69 ++++++++++++-----------------
 src/script/ir_builder/relax/ir.cc           | 23 +++++++---
 tests/python/relax/test_tvmscript_parser.py | 31 +++++++++++++
 3 files changed, 75 insertions(+), 48 deletions(-)

diff --git a/src/script/ir_builder/relax/frame.cc b/src/script/ir_builder/relax/frame.cc
index 3153c0770e..faf6bd6466 100644
--- a/src/script/ir_builder/relax/frame.cc
+++ b/src/script/ir_builder/relax/frame.cc
@@ -118,36 +118,23 @@ void BlockFrameNode::EnterWithScope() {
   }
 }
 
-class DataflowBlockRewriter : public tvm::relax::ExprMutator {
+class VarReplacer : public tvm::relax::ExprMutator {
  public:
-  static tvm::relax::DataflowBlock Rewrite(const tvm::relax::DataflowBlock& block,
-                                           const Array<tvm::relax::Var>& output_vars) {
-    DataflowBlockRewriter rewriter(output_vars);
-    return Downcast<tvm::relax::DataflowBlock>(rewriter.VisitBindingBlock(block));
+  explicit VarReplacer(
+      std::unordered_map<tvm::relax::Id, tvm::relax::Var, ObjectPtrHash, ObjectPtrEqual>
+          var_remap) {
+    var_remap_ = std::move(var_remap);
   }
 
- private:
-  explicit DataflowBlockRewriter(const Array<tvm::relax::Var>& output_vars) {
-    for (const tvm::relax::Var& var : output_vars) {
-      output_var_set_.insert(var.get());
-    }
-  }
-
-  tvm::relax::Var VisitVarDef_(const tvm::relax::DataflowVarNode* op) final {
-    auto it = output_var_set_.find(op);
-    if (it != output_var_set_.end()) {
-      // Rewrite dataflow vars to global vars
-      auto n = make_object<tvm::relax::VarNode>(*op);
-      tvm::relax::Var new_var(n);
-      this->var_remap_[op->vid] = new_var;
-      return new_var;
+  tvm::relax::Var VisitVarDef(const tvm::relax::Var& var) override {
+    // ExprMutator only applies var_remap_ at usage sites.  This
+    // applies var_remap_ at each definition site as well.
+    if (auto it = var_remap_.find(var->vid); it != var_remap_.end()) {
+      return it->second;
     } else {
-      return GetRef<tvm::relax::Var>(op);
+      return var;
     }
   }
-
- private:
-  std::unordered_set<const tvm::relax::VarNode*> output_var_set_;
 };
 
 void BlockFrameNode::ExitWithScope() {
@@ -164,25 +151,27 @@ void BlockFrameNode::ExitWithScope() {
 
   // Step 3. Rewrite the dataflow block.
   if (is_dataflow) {
-    // Step 3.1. Rewrite block binding
-    block = DataflowBlockRewriter::Rewrite(Downcast<tvm::relax::DataflowBlock>(block), output_vars);
-
-    // Step 3.2. Collect global vars' reference in bindings
-    Map<tvm::relax::Id, tvm::relax::Var> new_global_vars;
-    for (const tvm::relax::Binding& binding : block->bindings) {
-      if (!binding->var->IsInstance<tvm::relax::DataflowVarNode>()) {
-        new_global_vars.Set(binding->var->vid, binding->var);
-      }
+    // Step 3.0.  Define a map to replace variables
+    Array<tvm::relax::Var> new_output_vars;
+    std::unordered_map<tvm::relax::Id, tvm::relax::Var, ObjectPtrHash, ObjectPtrEqual> var_remap;
+    for (const auto& output_var : output_vars) {
+      tvm::relax::Var new_output_var(output_var->name_hint(), GetStructInfo(output_var));
+      new_output_vars.push_back(new_output_var);
+      var_remap[output_var->vid] = new_output_var;
     }
+    VarReplacer mutator(std::move(var_remap));
+
+    // Step 3.1. Rewrite block binding
+    block = mutator.VisitBindingBlock(block);
 
     // Step 3.3. Rewrite output vars
-    Array<tvm::relax::Var> new_output_vars;
-    for (const auto& var : output_vars) {
-      auto it = new_global_vars.find(var->vid);
-      ICHECK(it != new_global_vars.end());
-      new_output_vars.push_back((*it).second);
-    }
     output_vars = std::move(new_output_vars);
+
+    // Step 3.4 Rewrite usage of output var, if any
+    auto function = FindFunctionFrame("R.dataflow()");
+    if (function->output.defined()) {
+      function->output = mutator.VisitExpr(function->output.value());
+    }
   }
 
   // Step 3. Get the last frame from the IRBuilder frame stack.
@@ -196,8 +185,6 @@ void BlockFrameNode::ExitWithScope() {
 
   // Step 5. Push the block frame into the corresponding field of the last frame.
   if (const auto* seq_frame = last_frame.as<SeqExprFrameNode>()) {
-    ICHECK(!seq_frame->output.defined())
-        << "The function is not expected to have output values when emitting blocks.";
     auto frame = GetRef<SeqExprFrame>(seq_frame);
     frame->binding_blocks.push_back(block);
   } else {
diff --git a/src/script/ir_builder/relax/ir.cc b/src/script/ir_builder/relax/ir.cc
index 453c7fdb55..b2e75d0c36 100644
--- a/src/script/ir_builder/relax/ir.cc
+++ b/src/script/ir_builder/relax/ir.cc
@@ -117,20 +117,29 @@ void FuncRetValue(const tvm::relax::Expr& value) {
   const tvm::relax::BlockBuilder& block_builder = GetBlockBuilder();
   tvm::relax::Expr normalized_value = block_builder->Normalize(value);
 
+  IRBuilder ir_builder = IRBuilder::Current();
+
   // Step 1. The current Relax TVMScript syntax only allows function return appearing at the end of
   // a function body. Therefore if there is any unended block frame when dealing with function
   // return, we should end the block frame.
-  Optional<BlockFrame> block_frame = IRBuilder::Current()->GetLastFrame<BlockFrame>();
-  if (block_frame.defined()) {
-    block_frame.value()->ExitWithScope();
-    ICHECK(!IRBuilder::Current()->FindFrame<BlockFrame>())
-        << "ValueError: Relax functions don't support return in true/false branch of If Node.";
+
+  if (auto opt = ir_builder->GetLastFrame<BlockFrame>()) {
+    auto block_frame = opt.value();
+    for (const auto& var : tvm::relax::FreeVars(normalized_value)) {
+      if (var->IsInstance<tvm::relax::DataflowVarNode>()) {
+        block_frame->output_vars.push_back(var);
+      }
+    }
   }
   // Step 2. Add the output value to the function frame.
   FunctionFrame frame = FindFunctionFrame("return");
   CHECK(!frame->output.defined())
-      << "ValueError: Relax functions don't support multiple return statement. Please make sure "
-         "the return statement appears at the end of function.";
+      << "ValueError: "
+      << "Relax functions do not support multiple return statement.  "
+      << "However, return of " << normalized_value << " occurred after a return of "
+      << frame->output << ".  "
+      << "Please make sure function only has a single return statement, "
+      << "which appears at the end of function.";
 
   frame->output = std::move(normalized_value);
 }
diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py
index fd465f3201..fa62d14848 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -2410,5 +2410,36 @@ def test_conditional_may_use_symbolic_variables_from_function_scope():
     tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo)
 
 
+def test_return_from_dataflow_block():
+    """Return statements imply
+
+    The `R.output` statement in a `R.dataflow()` block marks a
+    variable that should be a `relax.Var` instead of a
+    `relax.DataflowVar`, allowing it to be used outside of the
+    `DataflowBlock` that defined it.  A relax function's output is not
+    part of any binding, and must not contain any `DataflowVar`, so
+    these are exposed implicitly.
+
+    """
+
+    @R.function(private=True)
+    def output_then_return(A: R.Tensor([16], "float16")):
+        with R.dataflow():
+            B = R.add(A, A)
+            C = R.multiply(B, B)
+            R.output(C)
+
+        return C
+
+    @R.function(private=True)
+    def return_inside_dataflow(A: R.Tensor([16], "float16")):
+        with R.dataflow():
+            B = R.add(A, A)
+            C = R.multiply(B, B)
+            return C
+
+    tvm.ir.assert_structural_equal(output_then_return, return_inside_dataflow)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to