This is an automated email from the ASF dual-hosted git repository.
csullivan pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new d9e1659d2a [Unity][Transform] Elide redundant bindings of dataflow
vars (#15341)
d9e1659d2a is described below
commit d9e1659d2a92815341e96105f7e1dde6c9d797a6
Author: Steven S. Lyubomirsky <[email protected]>
AuthorDate: Wed Aug 2 18:10:13 2023 -0400
[Unity][Transform] Elide redundant bindings of dataflow vars (#15341)
This PR addresses the second issue in #15245. In particular, it handles a
case that can arise when using `CanonicalizeBindings`: a dataflow var
whose only use is to be bound to the output var, like so:
```python
@R.function
def main() -> R.Tensor((), "int32"):
with R.dataflow():
y = R.const(1)
n = y
R.output(n)
return n
```
The only use for `y` is to be bound to `n`, which is the `DataflowBlock`'s
output.
This PR adds a pass called `FoldDataflowBlockOutput` that
detects this case and eliminates the intermediate binding, leaving
simply this for the example above:
```python
@R.function
def main() -> R.Tensor((), "int32"):
with R.dataflow():
n = R.const(1)
R.output(n)
return n
```
* Coalesce intermediate dataflow vars in dead code elimination
* Fix typo in comment
* Move dataflow var elision to its own pass
* Remove redundant whitespace
---
include/tvm/relax/transform.h | 10 ++
python/tvm/relax/transform/transform.py | 15 ++
src/relax/transform/fold_dataflow_block_output.cc | 192 +++++++++++++++++++++
.../test_transform_fold_dataflow_block_output.py | 149 ++++++++++++++++
4 files changed, 366 insertions(+)
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index dc2476f383..8d01262aab 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -189,6 +189,16 @@ TVM_DLL Pass BindParams(String func_name, Map<String,
runtime::NDArray> params);
*/
TVM_DLL Pass FoldConstant();
+/*!
+ * \brief If a dataflow var is used only in a binding to the dataflow block
+ * output var (i.e., a non-dataflow var), this removes the dataflow var
+ * and replaces the output var's binding with the dataflow var's direct
definition.
+ *
+ * This "cleans up" a situation that commonly arises when using
`CanonicalizeBindings`
+ * and `DeadCodeElimination`.
+ **/
+TVM_DLL Pass FoldDataflowBlockOutput();
+
/*!
* \brief Legalize high-level operator calls in Relax functions to call_tir
* with corresponding low-level TIR PrimFuncs.
diff --git a/python/tvm/relax/transform/transform.py
b/python/tvm/relax/transform/transform.py
index 73267c43ae..f7b0b4e9dd 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -453,6 +453,21 @@ def FoldConstant() -> tvm.ir.transform.Pass:
return _ffi_api.FoldConstant() # type: ignore
+def FoldDataflowBlockOutput() -> tvm.ir.transform.Pass:
+ """If a dataflow var is used only in a binding to the dataflow block
+ output var (i.e., a non-dataflow var), this removes the dataflow var
+ and replaces the output var's binding with the dataflow var's direct
definition.
+
+ This "cleans up" a situation that commonly arises when using
`CanonicalizeBindings`
+ and `DeadCodeElimination`.
+
+ Returns
+ -------
+ ret: tvm.ir.transform.Pass
+ """
+ return _ffi_api.FoldDataflowBlockOutput() # type: ignore
+
+
def AnnotateTIROpPattern() -> tvm.ir.transform.Pass:
"""Annotate Op Pattern Kind for TIR functions
diff --git a/src/relax/transform/fold_dataflow_block_output.cc
b/src/relax/transform/fold_dataflow_block_output.cc
new file mode 100644
index 0000000000..dc6e8a68c5
--- /dev/null
+++ b/src/relax/transform/fold_dataflow_block_output.cc
@@ -0,0 +1,192 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ *
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/transform/fold_dataflow_block_output.cc
+ * \brief Pass that folds dataflow vars used only for binding
+ * the dataflow block output
+ * directly into the output
+ *
+ * If a dataflow var is used only in a binding to the dataflow block's
+ * output var (a non-dataflow var), this pass removes the dataflow var
+ * binding from the block and uses the dataflow var's definition
+ * directly in the output binding.
+ */
+
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+
+namespace tvm {
+namespace relax {
+
+// If a dataflow var is used *only* as the RHS of a binding to the dataflow
block output
+// (i.e., an ordinary var), then we can get rid of that dataflow var and bind
the DF var's
+// definition directly to the output.
+DataflowBlock FoldDataflowBlockOutput(const DataflowBlock& block) {
+ // helper: gather all dataflow vars inside an expression
+ class DataflowVarGatherer : public ExprVisitor {
+ public:
+ // ignore inner functions
+ void VisitExpr_(const FunctionNode* _) override {}
+
+ void VisitExpr_(const DataflowVarNode* var) override {
vars_.insert(GetRef<DataflowVar>(var)); }
+
+ std::unordered_set<DataflowVar, ObjectPtrHash, ObjectPtrEqual>
Gather(const Expr& expr) {
+ VisitExpr(expr);
+ return vars_;
+ }
+
+ std::unordered_set<DataflowVar, ObjectPtrHash, ObjectPtrEqual> vars_;
+ };
+
+ // first we search for dataflow vars for which the condition is met:
+ // exclude if found anywhere other than RHS of a binding to an ordinary var
(or more than once)
+ // candidate set -> eliminate if we find somewhere it's not supposed to be
+ class CandidateFinder : public ExprVisitor {
+ public:
+ void VisitBinding_(const VarBindingNode* binding) override {
+ ProcessBinding(binding->var, binding->value);
+ }
+
+ void VisitBinding_(const MatchCastNode* binding) override {
+ ProcessBinding(binding->var, binding->value);
+ }
+
+ void ProcessBinding(const Var& var, const Expr& value) {
+ if (var.as<DataflowVarNode>()) {
+ // add definition to binding map
+ binding_map_[Downcast<DataflowVar>(var)] = value;
+
+ // disqualify any dataflow vars in the RHS (since the LHS isn't an
ordinary var)
+ DataflowVarGatherer gatherer;
+ auto disqualified = gatherer.Gather(value);
+ for (auto var : disqualified) {
+ disqualified_set_.insert(var);
+ }
+ } else {
+ // the LHS is an output, so disqualify if the RHS is not a single
dataflow var
+ // or if the var has been output before
+ if (const auto* rhs_var = value.as<DataflowVarNode>()) {
+ if (output_vars_.count(GetRef<DataflowVar>(rhs_var))) {
+ disqualified_set_.insert(GetRef<DataflowVar>(rhs_var));
+ }
+ output_vars_.insert(GetRef<DataflowVar>(rhs_var));
+ } else {
+ DataflowVarGatherer gatherer;
+ auto disqualified = gatherer.Gather(value);
+ for (auto var : disqualified) {
+ disqualified_set_.insert(var);
+ }
+ }
+ }
+ }
+
+ std::unordered_map<DataflowVar, Expr, ObjectPtrHash, ObjectPtrEqual>
FindCandidates(
+ const DataflowBlock& block) {
+ VisitBindingBlock(block);
+ // candidates: the output vars that are not in the disqualified set
+ std::unordered_map<DataflowVar, Expr, ObjectPtrHash, ObjectPtrEqual> ret;
+ for (auto var : output_vars_) {
+ if (!disqualified_set_.count(var)) {
+ ret[var] = binding_map_.at(var);
+ }
+ }
+ return ret;
+ }
+
+ std::unordered_map<DataflowVar, Expr, ObjectPtrHash, ObjectPtrEqual>
binding_map_;
+ std::unordered_set<DataflowVar, ObjectPtrHash, ObjectPtrEqual>
disqualified_set_;
+ std::unordered_set<DataflowVar, ObjectPtrHash, ObjectPtrEqual>
output_vars_;
+ };
+
+ // given a candidate map (dataflow vars that should be eliminated mapped to
their definitions),
+ // remove the bindings corresponding to those DF vars and replace the vars
with their definitions
+ // when they appear on the RHS of a binding to an output var (non-DF var)
+ class BindingUpdater : public ExprMutator {
+ public:
+ explicit BindingUpdater(
+ const std::unordered_map<DataflowVar, Expr, ObjectPtrHash,
ObjectPtrEqual>& candidate_map)
+ : candidate_map_(candidate_map) {}
+
+ void VisitBinding_(const VarBindingNode* binding) override {
+ // case 1: if the LHS is a DF node in the candidate map, erase the
binding
+ if (binding->var.as<DataflowVarNode>() &&
+ candidate_map_.count(Downcast<DataflowVar>(binding->var))) {
+ return;
+ }
+ // case 2: if the RHS consists only of a DF node in the candidate map,
replace the value
+ // with the definition from the candidate map
+ if (!binding->var.as<DataflowVarNode>() &&
binding->value.as<DataflowVarNode>() &&
+ candidate_map_.count(Downcast<DataflowVar>(binding->value))) {
+ builder_->EmitNormalized(
+ VarBinding(binding->var,
candidate_map_.at(Downcast<DataflowVar>(binding->value))));
+ return;
+ }
+ // case 3: if neither, use the default logic
+ ExprMutator::VisitBinding_(binding);
+ };
+
+ void VisitBinding_(const MatchCastNode* binding) {
+ // case 1: if the LHS is a DF node in the candidate map, erase the
binding
+ if (binding->var.as<DataflowVarNode>() &&
+ candidate_map_.count(Downcast<DataflowVar>(binding->var))) {
+ return;
+ }
+ // case 2: if the RHS consists only of a DF node in the candidate map,
replace the value
+ // with the definition from the candidate map
+ if (!binding->var.as<DataflowVarNode>() &&
binding->value.as<DataflowVarNode>() &&
+ candidate_map_.count(Downcast<DataflowVar>(binding->value))) {
+ builder_->EmitNormalized(MatchCast(binding->var,
+
candidate_map_.at(Downcast<DataflowVar>(binding->value)),
+ binding->struct_info));
+ return;
+ }
+ // case 3: if neither, use the default logic
+ ExprMutator::VisitBinding_(binding);
+ }
+
+ const std::unordered_map<DataflowVar, Expr, ObjectPtrHash,
ObjectPtrEqual>& candidate_map_;
+ };
+
+ CandidateFinder finder;
+ auto candidate_map = finder.FindCandidates(block);
+ BindingUpdater updater(candidate_map);
+ auto new_block = updater.VisitBindingBlock(block);
+ return Downcast<DataflowBlock>(new_block);
+}
+
+namespace transform {
+
+Pass FoldDataflowBlockOutput() {
+ const runtime::TypedPackedFunc<DataflowBlock(DataflowBlock, IRModule,
PassContext)>& pass_func =
+ [=](DataflowBlock b, IRModule m, PassContext pc) {
+ return relax::FoldDataflowBlockOutput(b);
+ };
+ return CreateDataflowBlockPass(pass_func, 1, "FoldDataflowBlockOutput", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.FoldDataflowBlockOutput")
+ .set_body_typed(FoldDataflowBlockOutput);
+
+} // namespace transform
+} // namespace relax
+} // namespace tvm
diff --git a/tests/python/relax/test_transform_fold_dataflow_block_output.py
b/tests/python/relax/test_transform_fold_dataflow_block_output.py
new file mode 100644
index 0000000000..426e2b2710
--- /dev/null
+++ b/tests/python/relax/test_transform_fold_dataflow_block_output.py
@@ -0,0 +1,149 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+from tvm.relax.transform import FoldDataflowBlockOutput
+from tvm.script.parser import ir as I, relax as R, tir as T
+
+
+def verify(input, expected):
+ tvm.ir.assert_structural_equal(FoldDataflowBlockOutput()(input), expected)
+
+
+def test_basic_example():
+ @tvm.script.ir_module
+ class Input:
+ @R.function
+ def main() -> R.Tensor((), "int32"):
+ with R.dataflow():
+ y = R.const(1)
+ n = y
+ R.output(n)
+ return n
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main() -> R.Tensor((), "int32"):
+ with R.dataflow():
+ n = R.const(1)
+ R.output(n)
+ return n
+
+ verify(Input, Expected)
+
+
+def test_match_cast():
+ @tvm.script.ir_module
+ class Input:
+ @R.function
+ def main() -> R.Tensor((), "int32"):
+ with R.dataflow():
+ y = R.const(1)
+ n = R.match_cast(y, R.Tensor((), "int32"))
+ R.output(n)
+ return n
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main() -> R.Tensor((), "int32"):
+ with R.dataflow():
+ n = R.match_cast(R.const(1), R.Tensor((), "int32"))
+ R.output(n)
+ return n
+
+ verify(Input, Expected)
+
+
+def test_unable_to_fold():
+ @tvm.script.ir_module
+ class MultipleUse:
+ @R.function
+ def main() -> R.Tensor((), "int32"):
+ with R.dataflow():
+ y = R.const(1)
+ # multiple uses -> cannot coalesce
+ m = R.add(y, y)
+ n = y
+ R.output(n)
+ return n
+
+ @tvm.script.ir_module
+ class ComplexExpr:
+ @R.function
+ def main() -> R.Tensor((), "int32"):
+ with R.dataflow():
+ y = R.const(1)
+ # y does not appear by itself -> cannot coalesce
+ n = R.add(y, y)
+ R.output(n)
+ return n
+
+ verify(MultipleUse, MultipleUse)
+ verify(ComplexExpr, ComplexExpr)
+
+
+def test_multiple_outputs():
+ @tvm.script.ir_module
+ class Input:
+ @R.function
+ def main() -> R.Tensor((), "int32"):
+ with R.dataflow():
+ x = R.const(1)
+ y = R.const(1)
+ z = R.const(1)
+ l = x
+ m = y
+ n = z
+ R.output(l, m, n)
+ return n
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main() -> R.Tensor((), "int32"):
+ with R.dataflow():
+ l = R.const(1)
+ m = R.const(1)
+ n = R.const(1)
+ R.output(l, m, n)
+ return n
+
+ verify(Input, Expected)
+
+
+def test_multiply_used_in_outputs():
+ # cannot fold in this case
+ @tvm.script.ir_module
+ class UsedInMultipleOutputs:
+ @R.function
+ def main() -> R.Tensor((), "int32"):
+ with R.dataflow():
+ x = R.const(1)
+ l = x
+ m = x
+ n = x
+ R.output(l, m, n)
+ return n
+
+ verify(UsedInMultipleOutputs, UsedInMultipleOutputs)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()