This is an automated email from the ASF dual-hosted git repository.

csullivan pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new d9e1659d2a [Unity][Transform] Elide redundant bindings of dataflow 
vars (#15341)
d9e1659d2a is described below

commit d9e1659d2a92815341e96105f7e1dde6c9d797a6
Author: Steven S. Lyubomirsky <[email protected]>
AuthorDate: Wed Aug 2 18:10:13 2023 -0400

    [Unity][Transform] Elide redundant bindings of dataflow vars (#15341)
    
    This PR addresses the second issue in #15245. In particular, it handles the 
case that can arise when using `CanonicalizeBindings` of having a dataflow var 
whose only use is being bound to the output var, like so:
    
    ```python
    @R.function
    def main() -> R.Tensor((), "int32"):
            with R.dataflow():
                y = R.const(1)
                n = y
                R.output(n)
            return n
    ```
    The only use for `y` is to be bound to `n`, which is the `DataflowBlock`'s 
output.
    
    This PR adds a pass called `FoldDataflowBlockOutput` that 
detects this case and eliminates the intermediate binding, leaving in this 
example simply this:
    ```python
    @R.function
    def main() -> R.Tensor((), "int32"):
        with R.dataflow():
            n = R.const(1)
            R.output(n)
        return n
    ```
    
    * Coalesce intermediate dataflow vars in dead code elimination
    
    * Fix typo in comment
    
    * Move dataflow var elision to its own pass
    
    * Remove redundant whitespace
---
 include/tvm/relax/transform.h                      |  10 ++
 python/tvm/relax/transform/transform.py            |  15 ++
 src/relax/transform/fold_dataflow_block_output.cc  | 192 +++++++++++++++++++++
 .../test_transform_fold_dataflow_block_output.py   | 149 ++++++++++++++++
 4 files changed, 366 insertions(+)

diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index dc2476f383..8d01262aab 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -189,6 +189,16 @@ TVM_DLL Pass BindParams(String func_name, Map<String, 
runtime::NDArray> params);
  */
 TVM_DLL Pass FoldConstant();
 
+/*!
+ * \brief If a dataflow var is used only in a binding to the dataflow block
+ * output var (i.e., a non-dataflow var), this removes the dataflow var
+ * and replaces the output var's binding with the dataflow var's direct 
definition.
+ *
+ * This "cleans up" a situation that commonly arises when using 
`CanonicalizeBindings`
+ * and `DeadCodeElimination`.
+ */
+TVM_DLL Pass FoldDataflowBlockOutput();
+
 /*!
  * \brief Legalize high-level operator calls in Relax functions to call_tir
  * with corresponding low-level TIR PrimFuncs.
diff --git a/python/tvm/relax/transform/transform.py 
b/python/tvm/relax/transform/transform.py
index 73267c43ae..f7b0b4e9dd 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -453,6 +453,21 @@ def FoldConstant() -> tvm.ir.transform.Pass:
     return _ffi_api.FoldConstant()  # type: ignore
 
 
def FoldDataflowBlockOutput() -> tvm.ir.transform.Pass:
    """If a dataflow var is used only in a binding to the dataflow block
    output var (i.e., a non-dataflow var), this removes the dataflow var
    and replaces the output var's binding with the dataflow var's direct
    definition.

    This "cleans up" a situation that commonly arises when using
    `CanonicalizeBindings` and `DeadCodeElimination`.

    Returns
    -------
    ret: tvm.ir.transform.Pass
        The registered pass that folds redundant dataflow-var bindings
        directly into the dataflow block's output bindings.
    """
    return _ffi_api.FoldDataflowBlockOutput()  # type: ignore
+
+
 def AnnotateTIROpPattern() -> tvm.ir.transform.Pass:
     """Annotate Op Pattern Kind for TIR functions
 
diff --git a/src/relax/transform/fold_dataflow_block_output.cc 
b/src/relax/transform/fold_dataflow_block_output.cc
new file mode 100644
index 0000000000..dc6e8a68c5
--- /dev/null
+++ b/src/relax/transform/fold_dataflow_block_output.cc
@@ -0,0 +1,192 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ *
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/transform/fold_dataflow_block_output.cc
+ * \brief Pass that folds dataflow vars used only for binding
+ *   the dataflow block output
+ *   directly into the output
+ *
+ * If a dataflow var is used only in a binding to the dataflow block's
+ * output var (a non-dataflow var), this pass removes the dataflow var
+ * binding from the block and uses the dataflow var's definition
+ * directly in the output binding.
+ */
+
#include <tvm/relax/analysis.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/transform.h>

#include <unordered_map>
#include <unordered_set>
+
+namespace tvm {
+namespace relax {
+
+// If a dataflow var is used *only* as the RHS of a binding to the dataflow 
block output
+// (i.e., an ordinary var), then we can get rid of that dataflow var and bind 
the DF var's
+// definition directly to the output.
+DataflowBlock FoldDataflowBlockOutput(const DataflowBlock& block) {
+  // helper: gather all dataflow vars inside an expression
+  class DataflowVarGatherer : public ExprVisitor {
+   public:
+    // ignore inner functions
+    void VisitExpr_(const FunctionNode* _) override {}
+
+    void VisitExpr_(const DataflowVarNode* var) override { 
vars_.insert(GetRef<DataflowVar>(var)); }
+
+    std::unordered_set<DataflowVar, ObjectPtrHash, ObjectPtrEqual> 
Gather(const Expr& expr) {
+      VisitExpr(expr);
+      return vars_;
+    }
+
+    std::unordered_set<DataflowVar, ObjectPtrHash, ObjectPtrEqual> vars_;
+  };
+
+  // first we search for dataflow vars for which the condition is met:
+  // exclude if found anywhere other than RHS of a binding to an ordinary var 
(or more than once)
+  // candidate set -> eliminate if we find somewhere it's not supposed to be
+  class CandidateFinder : public ExprVisitor {
+   public:
+    void VisitBinding_(const VarBindingNode* binding) override {
+      ProcessBinding(binding->var, binding->value);
+    }
+
+    void VisitBinding_(const MatchCastNode* binding) override {
+      ProcessBinding(binding->var, binding->value);
+    }
+
+    void ProcessBinding(const Var& var, const Expr& value) {
+      if (var.as<DataflowVarNode>()) {
+        // add definition to binding map
+        binding_map_[Downcast<DataflowVar>(var)] = value;
+
+        // disqualify any dataflow vars in the RHS (since the LHS isn't an 
ordinary var)
+        DataflowVarGatherer gatherer;
+        auto disqualified = gatherer.Gather(value);
+        for (auto var : disqualified) {
+          disqualified_set_.insert(var);
+        }
+      } else {
+        // the LHS is an output, so disqualify if the RHS is not a single 
dataflow var
+        // or if the var has been output before
+        if (const auto* rhs_var = value.as<DataflowVarNode>()) {
+          if (output_vars_.count(GetRef<DataflowVar>(rhs_var))) {
+            disqualified_set_.insert(GetRef<DataflowVar>(rhs_var));
+          }
+          output_vars_.insert(GetRef<DataflowVar>(rhs_var));
+        } else {
+          DataflowVarGatherer gatherer;
+          auto disqualified = gatherer.Gather(value);
+          for (auto var : disqualified) {
+            disqualified_set_.insert(var);
+          }
+        }
+      }
+    }
+
+    std::unordered_map<DataflowVar, Expr, ObjectPtrHash, ObjectPtrEqual> 
FindCandidates(
+        const DataflowBlock& block) {
+      VisitBindingBlock(block);
+      // candidates: the output vars that are not in the disqualified set
+      std::unordered_map<DataflowVar, Expr, ObjectPtrHash, ObjectPtrEqual> ret;
+      for (auto var : output_vars_) {
+        if (!disqualified_set_.count(var)) {
+          ret[var] = binding_map_.at(var);
+        }
+      }
+      return ret;
+    }
+
+    std::unordered_map<DataflowVar, Expr, ObjectPtrHash, ObjectPtrEqual> 
binding_map_;
+    std::unordered_set<DataflowVar, ObjectPtrHash, ObjectPtrEqual> 
disqualified_set_;
+    std::unordered_set<DataflowVar, ObjectPtrHash, ObjectPtrEqual> 
output_vars_;
+  };
+
+  // given a candidate map (dataflow vars that should be eliminated mapped to 
their definitions),
+  // remove the bindings corresponding to those DF vars and replace the vars 
with their definitions
+  // when the appear on the RHS of a binding to an output var (non-DF var)
+  class BindingUpdater : public ExprMutator {
+   public:
+    explicit BindingUpdater(
+        const std::unordered_map<DataflowVar, Expr, ObjectPtrHash, 
ObjectPtrEqual>& candidate_map)
+        : candidate_map_(candidate_map) {}
+
+    void VisitBinding_(const VarBindingNode* binding) override {
+      // case 1: if the LHS is a DF node in the candidate map, erase the 
binding
+      if (binding->var.as<DataflowVarNode>() &&
+          candidate_map_.count(Downcast<DataflowVar>(binding->var))) {
+        return;
+      }
+      // case 2: if the RHS consists only of a DF node in the candidate map, 
replace the value
+      //   with the definition from the candidate map
+      if (!binding->var.as<DataflowVarNode>() && 
binding->value.as<DataflowVarNode>() &&
+          candidate_map_.count(Downcast<DataflowVar>(binding->value))) {
+        builder_->EmitNormalized(
+            VarBinding(binding->var, 
candidate_map_.at(Downcast<DataflowVar>(binding->value))));
+        return;
+      }
+      // case 3: if neither, use the default logic
+      ExprMutator::VisitBinding_(binding);
+    };
+
+    void VisitBinding_(const MatchCastNode* binding) {
+      // case 1: if the LHS is a DF node in the candidate map, erase the 
binding
+      if (binding->var.as<DataflowVarNode>() &&
+          candidate_map_.count(Downcast<DataflowVar>(binding->var))) {
+        return;
+      }
+      // case 2: if the RHS consists only of a DF node in the candidate map, 
replace the value
+      //   with the definition from the candidate map
+      if (!binding->var.as<DataflowVarNode>() && 
binding->value.as<DataflowVarNode>() &&
+          candidate_map_.count(Downcast<DataflowVar>(binding->value))) {
+        builder_->EmitNormalized(MatchCast(binding->var,
+                                           
candidate_map_.at(Downcast<DataflowVar>(binding->value)),
+                                           binding->struct_info));
+        return;
+      }
+      // case 3: if neither, use the default logic
+      ExprMutator::VisitBinding_(binding);
+    }
+
+    const std::unordered_map<DataflowVar, Expr, ObjectPtrHash, 
ObjectPtrEqual>& candidate_map_;
+  };
+
+  CandidateFinder finder;
+  auto candidate_map = finder.FindCandidates(block);
+  BindingUpdater updater(candidate_map);
+  auto new_block = updater.VisitBindingBlock(block);
+  return Downcast<DataflowBlock>(new_block);
+}
+
+namespace transform {
+
+Pass FoldDataflowBlockOutput() {
+  const runtime::TypedPackedFunc<DataflowBlock(DataflowBlock, IRModule, 
PassContext)>& pass_func =
+      [=](DataflowBlock b, IRModule m, PassContext pc) {
+        return relax::FoldDataflowBlockOutput(b);
+      };
+  return CreateDataflowBlockPass(pass_func, 1, "FoldDataflowBlockOutput", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.FoldDataflowBlockOutput")
+    .set_body_typed(FoldDataflowBlockOutput);
+
+}  // namespace transform
+}  // namespace relax
+}  // namespace tvm
diff --git a/tests/python/relax/test_transform_fold_dataflow_block_output.py 
b/tests/python/relax/test_transform_fold_dataflow_block_output.py
new file mode 100644
index 0000000000..426e2b2710
--- /dev/null
+++ b/tests/python/relax/test_transform_fold_dataflow_block_output.py
@@ -0,0 +1,149 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+from tvm.relax.transform import FoldDataflowBlockOutput
+from tvm.script.parser import ir as I, relax as R, tir as T
+
+
def verify(input, expected):
    """Assert that FoldDataflowBlockOutput transforms `input` into `expected`."""
    tvm.ir.assert_structural_equal(FoldDataflowBlockOutput()(input), expected)
+
+
def test_basic_example():
    # `y` is used only once, as the entire RHS of the output binding `n = y`,
    # so the pass folds y's definition directly into n's binding.
    @tvm.script.ir_module
    class Input:
        @R.function
        def main() -> R.Tensor((), "int32"):
            with R.dataflow():
                y = R.const(1)
                n = y
                R.output(n)
            return n

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main() -> R.Tensor((), "int32"):
            with R.dataflow():
                n = R.const(1)
                R.output(n)
            return n

    verify(Input, Expected)
+
+
def test_match_cast():
    # A match_cast whose value is a fold candidate is handled like a var
    # binding: the candidate's definition is substituted as the cast value.
    @tvm.script.ir_module
    class Input:
        @R.function
        def main() -> R.Tensor((), "int32"):
            with R.dataflow():
                y = R.const(1)
                n = R.match_cast(y, R.Tensor((), "int32"))
                R.output(n)
            return n

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main() -> R.Tensor((), "int32"):
            with R.dataflow():
                n = R.match_cast(R.const(1), R.Tensor((), "int32"))
                R.output(n)
            return n

    verify(Input, Expected)
+
+
def test_unable_to_fold():
    # Disqualifying cases: the pass must leave both modules unchanged.
    @tvm.script.ir_module
    class MultipleUse:
        @R.function
        def main() -> R.Tensor((), "int32"):
            with R.dataflow():
                y = R.const(1)
                # multiple uses -> cannot coalesce
                m = R.add(y, y)
                n = y
                R.output(n)
            return n

    @tvm.script.ir_module
    class ComplexExpr:
        @R.function
        def main() -> R.Tensor((), "int32"):
            with R.dataflow():
                y = R.const(1)
                # y does not appear by itself -> cannot coalesce
                n = R.add(y, y)
                R.output(n)
            return n

    # expected output == input in both cases (no folding performed)
    verify(MultipleUse, MultipleUse)
    verify(ComplexExpr, ComplexExpr)
+
+
def test_multiple_outputs():
    # Each of x, y, z is consumed exactly once by a distinct output binding,
    # so all three intermediate bindings are elided independently.
    @tvm.script.ir_module
    class Input:
        @R.function
        def main() -> R.Tensor((), "int32"):
            with R.dataflow():
                x = R.const(1)
                y = R.const(1)
                z = R.const(1)
                l = x
                m = y
                n = z
                R.output(l, m, n)
            return n

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main() -> R.Tensor((), "int32"):
            with R.dataflow():
                l = R.const(1)
                m = R.const(1)
                n = R.const(1)
                R.output(l, m, n)
            return n

    verify(Input, Expected)
+
+
def test_multiply_used_in_outputs():
    # cannot fold in this case
    # (x feeds several output bindings; folding would duplicate its definition)
    @tvm.script.ir_module
    class UsedInMultipleOutputs:
        @R.function
        def main() -> R.Tensor((), "int32"):
            with R.dataflow():
                x = R.const(1)
                l = x
                m = x
                n = x
                R.output(l, m, n)
            return n

    # expected output == input (no folding performed)
    verify(UsedInMultipleOutputs, UsedInMultipleOutputs)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()

Reply via email to