This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new fbfa926585 [Relax] Implement relax.transform.TopologicalSort (#16697)
fbfa926585 is described below

commit fbfa92658568428b27c6ee5762ab7fe2f7c0b415
Author: Eric Lunderberg <lunderb...@users.noreply.github.com>
AuthorDate: Sun Mar 17 17:17:18 2024 -0500

    [Relax] Implement relax.transform.TopologicalSort (#16697)
    
    * [Relax] Implement relax.transform.TopologicalSort
    
    This commit implements a utility `relax.transform.TopologicalSort`,
    which can re-order the bindings that occur in a
    `relax.DataflowBlock`.  This is not intended for use in a
    general-purpose optimization pipeline, but instead as a utility that
    may be used as needed in specific cases, for example, to normalize
    unit tests that should not depend on the order of variable bindings.
    
    * Update docstring according to review comment
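
    A minimal usage sketch (the IRModule below is an illustrative
    example, not part of this commit; the pass invocation mirrors the
    unit tests added here):

        import tvm
        from tvm.script import ir as I, relax as R

        @I.ir_module
        class Module:
            @R.function
            def main(A: R.Tensor):
                with R.dataflow():
                    B = R.add(A, R.const(2))
                    C = R.add(A, R.const(1))
                    D = R.add(C, B)
                    R.output(D)
                return D

        # Re-order the bindings in each DataflowBlock, using a
        # depth-first traversal that starts from the function inputs.
        sorted_mod = tvm.relax.transform.TopologicalSort(
            order="depth-first", direction="from-inputs"
        )(Module)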
---
 python/tvm/relax/transform/__init__.py             |   1 +
 python/tvm/relax/transform/transform.py            |  23 ++
 src/relax/transform/topological_sort.cc            | 377 +++++++++++++++++
 .../relax/test_transform_topological_sort.py       | 457 +++++++++++++++++++++
 4 files changed, 858 insertions(+)

diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py
index c3fb0f23be..7daa36cd2e 100644
--- a/python/tvm/relax/transform/__init__.py
+++ b/python/tvm/relax/transform/__init__.py
@@ -72,6 +72,7 @@ from .transform import (
     StaticPlanBlockMemory,
     ToMixedPrecision,
     ToNonDataflow,
+    TopologicalSort,
     UpdateParamStructInfo,
     UpdateVDevice,
     VMBuiltinLower,
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index e4c66558f5..9ef5133b71 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -233,6 +233,29 @@ def ToNonDataflow() -> tvm.ir.transform.Pass:
     return _ffi_api.ToNonDataflow()  # type: ignore
 
 
+def TopologicalSort(order="depth-first", direction="from-inputs") -> tvm.ir.transform.Pass:
+    """Sort the bindings within each relax.DataflowBlock in the specified order
+
+    Parameters
+    ----------
+    order: str
+
+        The order in which bindings should be emitted.  Allowed values
+        are "depth-first" and "breadth-first".
+
+    direction: str
+
+        The direction in which the sort should be performed.  Allowed
+        values are "from-inputs" and "from-outputs".
+
+    Returns
+    -------
+    ret: tvm.ir.transform.Pass
+
+    """
+    return _ffi_api.TopologicalSort(order, direction)  # type: ignore
+
+
 def RemovePurityChecking() -> tvm.ir.transform.Pass:
     """Activate relax.force_pure on all pure functions in the module
     and unwrap all pure override ops into the normal versions.
diff --git a/src/relax/transform/topological_sort.cc b/src/relax/transform/topological_sort.cc
new file mode 100644
index 0000000000..a366ff4d12
--- /dev/null
+++ b/src/relax/transform/topological_sort.cc
@@ -0,0 +1,377 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/relax/transform/topological_sort.cc
+ * \brief Perform a topological sort of Dataflow blocks
+ */
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/struct_info.h>
+#include <tvm/relax/transform.h>
+
+#include <algorithm>
+#include <deque>
+#include <unordered_map>
+#include <unordered_set>
+#include <variant>
+#include <vector>
+
+namespace {
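+// Sentinel nodes used in the dependency graph alongside relax::Var:
+// InputNode represents the function's parameters, and OutputNode
+// represents the function's output.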
+struct InputNode {};
+struct OutputNode {};
+
+using DataflowNode = std::variant<InputNode, OutputNode, tvm::relax::Var>;
+
+bool operator==(const DataflowNode& a, const DataflowNode& b) {
+  if (const tvm::relax::Var* var_a = std::get_if<tvm::relax::Var>(&a)) {
+    if (const tvm::relax::Var* var_b = std::get_if<tvm::relax::Var>(&b)) {
+      const tvm::relax::VarNode* ptr_a = var_a->get();
+      const tvm::relax::VarNode* ptr_b = var_b->get();
+      return ptr_a == ptr_b;
+    }
+  }
+
+  return a.index() == b.index();
+}
+
+}  // namespace
+
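+// Hash implementation for DataflowNode, allowing it to be used as a
+// key in std::unordered_map and std::unordered_set.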
+template <>
+struct std::hash<DataflowNode> {
+  std::size_t operator()(const DataflowNode& node) const noexcept {
+    if (const tvm::relax::Var* var = std::get_if<tvm::relax::Var>(&node)) {
+      const tvm::relax::VarNode* ptr = var->get();
+      std::hash<decltype(ptr)> hasher;
+      return hasher(ptr);
+    } else {
+      auto index = node.index();
+      std::hash<decltype(index)> hasher;
+      return hasher(index);
+    }
+  }
+};
+
+namespace tvm {
+namespace relax {
+
+namespace {
+
+enum class TraversalOrder {
+  DepthFirst,
+  BreadthFirst,
+};
+
+enum class StartingLocation {
+  FromInputs,
+  FromOutputs,
+};
+
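+// Dependency information collected for a function.  `binding_order`
+// records the nodes in their original order of definition,
+// `downstream_users` maps each node to the nodes that use it, and
+// `upstream_requirements` maps each node to the nodes it requires.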
+struct Dependencies {
+  std::vector<DataflowNode> binding_order;
+  std::unordered_map<DataflowNode, std::deque<DataflowNode>> downstream_users;
+  std::unordered_map<DataflowNode, std::deque<DataflowNode>> upstream_requirements;
+};
+
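+// Visitor that walks a function and collects the dependency
+// information between its variable bindings, its parameters
+// (InputNode), and its output (OutputNode).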
+class BindingOrderCollector : ExprVisitor {
+ public:
+  static Dependencies Collect(const Expr& expr) {
+    BindingOrderCollector visitor;
+    visitor.dependencies_.binding_order.push_back(InputNode());
+    visitor(expr);
+
+    // If a variable has no inputs (e.g. `R.const(1)`) or is never
+    // used, it must still be handled somewhere to ensure it is
+    // visited correctly.  It is easiest to handle these cases here,
+    // by attaching them to the `InputNode`/`OutputNode` sentinels,
+    // so that the later depth/breadth-first traversal doesn't need
+    // to check for these special cases.
+    std::vector<DataflowNode> zero_input_bindings;
+    std::vector<DataflowNode> unused_bindings;
+    for (const auto& var : visitor.dependencies_.binding_order) {
+      if (std::holds_alternative<Var>(var)) {
+        if (!visitor.dependencies_.upstream_requirements.count(var)) {
+          zero_input_bindings.push_back(var);
+        }
+        if (!visitor.dependencies_.downstream_users.count(var)) {
+          unused_bindings.push_back(var);
+        }
+      }
+    }
+
+    for (const auto& var : zero_input_bindings) {
+      visitor.dependencies_.upstream_requirements[var].push_back(InputNode());
+      visitor.dependencies_.downstream_users[InputNode()].push_back(var);
+    }
+    for (auto it = unused_bindings.rbegin(); it != unused_bindings.rend(); it++) {
+      const auto& var = *it;
+      visitor.dependencies_.upstream_requirements[OutputNode()].push_front(var);
+      visitor.dependencies_.downstream_users[var].push_front(OutputNode());
+    }
+
+    visitor.dependencies_.binding_order.push_back(OutputNode());
+
+    return visitor.dependencies_;
+  }
+
+ private:
+  void VisitVarDef(const Var& var) override { dependencies_.binding_order.push_back(var); }
+
+  void VisitExpr_(const FunctionNode* op) override {
+    for (const auto& var : op->params) {
+      dependencies_.downstream_users[InputNode()].push_back(var);
+      dependencies_.upstream_requirements[var].push_back(InputNode());
+    }
+    VisitExpr(op->body);
+  }
+
+  void VisitBinding(const Binding& binding) override {
+    auto cache = current_binding_;
+    current_binding_ = binding->var;
+    ExprVisitor::VisitBinding(binding);
+    current_binding_ = cache;
+  }
+
+  void VisitExpr_(const VarNode* op) override {
+    Var upstream_requirement = GetRef<Var>(op);
+    auto downstream_user = current_binding_;
+
+    dependencies_.downstream_users[upstream_requirement].push_back(downstream_user);
+    dependencies_.upstream_requirements[downstream_user].push_back(upstream_requirement);
+  }
+
+  DataflowNode current_binding_ = OutputNode();
+  Dependencies dependencies_;
+};
+
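+// Mutator that re-orders the bindings within each DataflowBlock
+// according to the requested traversal order and starting location.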
+class TopologicalSorter : public ExprMutator {
+ public:
+  TopologicalSorter(TraversalOrder order, StartingLocation starting_location)
+      : order_(order), starting_location_(starting_location) {}
+
+  Expr VisitExpr_(const FunctionNode* op) override {
+    auto cached = dependencies_;
+    dependencies_ = BindingOrderCollector::Collect(GetRef<Expr>(op));
+
+    if (starting_location_ == StartingLocation::FromOutputs) {
+      std::reverse(dependencies_.binding_order.begin(), dependencies_.binding_order.end());
+    }
+    if (order_ == TraversalOrder::DepthFirst) {
+      for (auto& [upstream_var, downstream_vars] : dependencies_.downstream_users) {
+        std::reverse(downstream_vars.begin(), downstream_vars.end());
+      }
+    }
+
+    auto output = ExprMutator::VisitExpr_(op);
+    dependencies_ = cached;
+    return output;
+  }
+
+  BindingBlock VisitBindingBlock_(const DataflowBlockNode* op) override {
+    auto block = GetRef<DataflowBlock>(op);
+
+    // A map from not-yet-defined variables to the binding that will
+    // define the variable.  Items are removed from this map as they
+    // are collected into `new_bindings`.
+    std::unordered_map<Var, Binding, ObjectPtrHash, ObjectPtrEqual> to_emit;
+    for (const auto& binding : block->bindings) {
+      to_emit.insert({binding->var, binding});
+    }
+
+    // A lookup map of `Var -> Var` edges, used to find the bindings
+    // that may be emitted next.  When starting at the function
+    // inputs, this is the map from variables to the downstream
+    // variables that depend on them.  When starting at the function
+    // outputs, this is the map from variables to the upstream
+    // variables that they require.
+    const auto& forward_edge_lookup = [&]() {
+      switch (starting_location_) {
+        case StartingLocation::FromInputs:
+          return dependencies_.downstream_users;
+        case StartingLocation::FromOutputs:
+          return dependencies_.upstream_requirements;
+        default:
+          LOG(FATAL) << "Invalid enum value for StartingLocation";
+      }
+    }();
+
+    // A lookup map of `Var -> Var` edges, used to determine if a
+    // binding can legally be emitted.  When starting at the function
+    // inputs, this is the map from variables to the upstream
+    // variables that they require.  (i.e. A variable may not be
+    // defined earlier than its last input.)  When starting at the
+    // function outputs, this is the map from variables to the
+    // downstream variables that depend on them.  (i.e. A variable may
+    // not be defined later than its first usage.)
+    const auto& backward_edge_lookup = [&]() {
+      switch (starting_location_) {
+        case StartingLocation::FromInputs:
+          return dependencies_.upstream_requirements;
+        case StartingLocation::FromOutputs:
+          return dependencies_.downstream_users;
+        default:
+          LOG(FATAL) << "Invalid enum value for StartingLocation";
+      }
+    }();
+
+    // The search state for nodes that must still be visited.  When
+    // doing a depth-first search, this is used as a stack, with
+    // `push_back` and `pop_back`.  When doing a breadth-first search,
+    // this is used as a queue, with `push_back` and `pop_front`.  A
+    // `std::deque` is used to support these two use cases.
+    auto deque = [&]() -> std::deque<DataflowNode> {
+      switch (starting_location_) {
+        case StartingLocation::FromInputs:
+          return {InputNode()};
+        case StartingLocation::FromOutputs:
+          return {OutputNode()};
+        default:
+          LOG(FATAL) << "Invalid enum value for StartingLocation";
+      }
+    }();
+
+    std::unordered_set<DataflowNode> visited;
+
+    // Given a node that has just been visited (a variable definition,
+    // or the InputNode/OutputNode sentinel used as the starting point),
+    // mark any adjacent nodes that are now ready to visit.
+    auto push_descendents_to_stack = [&](const DataflowNode& var) {
+      auto it = forward_edge_lookup.find(var);
+      if (it == forward_edge_lookup.end()) {
+        return;
+      }
+      const auto& adjacent_vars = it->second;
+
+      for (const auto& adjacent_var : adjacent_vars) {
+        bool legal_to_output = [&]() -> bool {
+          if (visited.count(adjacent_var)) {
+            return false;
+          }
+
+          auto it = backward_edge_lookup.find(adjacent_var);
+          ICHECK(it != backward_edge_lookup.end());
+          const auto& prerequisites = it->second;
+          return std::all_of(prerequisites.begin(), prerequisites.end(),
+                             [&visited](const auto& var) { return visited.count(var); });
+        }();
+
+        if (legal_to_output) {
+          deque.push_back(adjacent_var);
+        }
+      }
+    };
+
+    std::vector<Binding> new_bindings;
+    while (deque.size()) {
+      DataflowNode visiting;
+      switch (order_) {
+        case TraversalOrder::DepthFirst: {
+          visiting = deque.back();
+          deque.pop_back();
+          break;
+        }
+        case TraversalOrder::BreadthFirst: {
+          visiting = deque.front();
+          deque.pop_front();
+          break;
+        }
+        default: {
+          LOG(FATAL) << "Invalid value for TraversalOrder: " << static_cast<int>(order_);
+        }
+      }
+
+      if (auto var = std::get_if<Var>(&visiting)) {
+        if (auto iter_emit = to_emit.find(*var); iter_emit != to_emit.end()) {
+          new_bindings.push_back(iter_emit->second);
+          to_emit.erase(iter_emit);
+        }
+      }
+      visited.insert(visiting);
+      push_descendents_to_stack(visiting);
+    }
+
+    ICHECK_EQ(to_emit.size(), 0) << "After visiting all bindings, "
+                                 << "no bindings should remain to emit.  "
+                                 << "However, bindings " <<
+        [&]() {
+          Array<Var> arr;
+          for (const auto& [var, binding] : to_emit) {
+            arr.push_back(var);
+          }
+          return arr;
+        }() << " still remain after emitting "
+                                 << Array<Binding>(new_bindings.begin(), new_bindings.end())
+                                        .Map([](const Binding& binding) { return binding->var; });
+
+    if (starting_location_ == StartingLocation::FromOutputs) {
+      std::reverse(new_bindings.begin(), new_bindings.end());
+    }
+
+    block.CopyOnWrite()->bindings = new_bindings;
+    return ExprMutator::VisitBindingBlock_(block.get());
+  }
+
+ private:
+  TraversalOrder order_;
+  StartingLocation starting_location_;
+  Dependencies dependencies_;
+};
+}  // namespace
+
+namespace transform {
+
+Pass TopologicalSort(TraversalOrder order, StartingLocation starting_location) {
+  auto pass_func = [=](Function func, IRModule, PassContext) {
+    TopologicalSorter mutator(order, starting_location);
+    return Downcast<Function>(mutator(func));
+  };
+  return relax::transform::CreateFunctionPass(pass_func, 0, "TopologicalSort", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.TopologicalSort")
+    .set_body_typed([](String order_str, String direction_str) -> Pass {
+      TraversalOrder order = [&]() {
+        if (order_str == "depth-first") {
+          return TraversalOrder::DepthFirst;
+        } else if (order_str == "breadth-first") {
+          return TraversalOrder::BreadthFirst;
+        } else {
+          LOG(FATAL) << "ValueError: "
+                     << "Invalid value for traversal order: \"" << order_str << "\".  "
+                     << "Allowed values are \"depth-first\" or \"breadth-first\"";
+        }
+      }();
+
+      StartingLocation starting_location = [&]() {
+        if (direction_str == "from-inputs") {
+          return StartingLocation::FromInputs;
+        } else if (direction_str == "from-outputs") {
+          return StartingLocation::FromOutputs;
+        } else {
+          LOG(FATAL) << "ValueError: "
+                     << "Invalid value for starting location: \"" << direction_str << "\".  "
+                     << "Allowed values are \"from-inputs\" or \"from-outputs\"";
+        }
+      }();
+
+      return TopologicalSort(order, starting_location);
+    });
+
+}  // namespace transform
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/tests/python/relax/test_transform_topological_sort.py b/tests/python/relax/test_transform_topological_sort.py
new file mode 100644
index 0000000000..3f11c081fa
--- /dev/null
+++ b/tests/python/relax/test_transform_topological_sort.py
@@ -0,0 +1,457 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+from tvm.script import ir as I, relax as R
+
+
+class BaseCompare(tvm.testing.CompareBeforeAfter):
+    def transform(self):
+        return tvm.relax.transform.TopologicalSort(
+            order=self.order,
+            direction=self.direction,
+        )
+
+
+class TestDepthFirstFromInputs(BaseCompare):
+    """Sort DataflowBlock bindings with DFS, starting from inputs
+
+    Starting with the inputs to the DataflowBlock, sort the variable
+    bindings according to their occurrence in a depth-first search.
+    """
+
+    order = "depth-first"
+    direction = "from-inputs"
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(A: R.Tensor):
+            with R.dataflow():
+                B1 = R.add(A, R.const(1))
+                B2 = R.add(A, R.const(2))
+                C1 = R.add(A, B1)
+                C2 = R.add(A, B2)
+                D = R.add(C1, C2)
+                R.output(D)
+            return D
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(A: R.Tensor):
+            with R.dataflow():
+                B1 = R.add(A, R.const(1))
+                C1 = R.add(A, B1)
+                B2 = R.add(A, R.const(2))
+                C2 = R.add(A, B2)
+                D = R.add(C1, C2)
+                R.output(D)
+            return D
+
+
+class TestDepthFirstFromInputWithConstant(BaseCompare):
+    """Topological sort must produce legal ordering.
+
+    Here, both `C1` and `C2` use the input tensor `A`.  However, they
+    also use the tensors `B1` and `B2`.  The bindings for `C1` and
+    `C2` may not be emitted until after all their inputs have been
+    emitted.
+
+    In addition, the bindings `B1` and `B2` do not require any of the
+    function inputs to compute.  If the DFS only used the function
+    parameters as the initial search nodes, it would fail to output
+    these variable bindings.
+    """
+
+    order = "depth-first"
+    direction = "from-inputs"
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(A: R.Tensor):
+            with R.dataflow():
+                B1 = R.const(1)
+                B2 = R.const(2)
+                C2 = R.add(A, B2)
+                C1 = R.add(A, B1)
+                D2 = R.add(A, C2)
+                D1 = R.add(A, C1)
+                E = R.add(D1, D2)
+                R.output(E)
+            return E
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(A: R.Tensor):
+            with R.dataflow():
+                B1 = R.const(1)
+                C1 = R.add(A, B1)
+                D1 = R.add(A, C1)
+                B2 = R.const(2)
+                C2 = R.add(A, B2)
+                D2 = R.add(A, C2)
+                E = R.add(D1, D2)
+                R.output(E)
+            return E
+
+
+class TestDepthFirstFromInputWithMultipleInputs(BaseCompare):
+    """Use parameter order for deterministic sort
+
+    Here, both `C1` and `C2` use the input tensor `A`, as well as
+    input tensors `B1` and `B2`, respectively.  Since `B1` appears
+    before `B2`, `C1` should be sorted before `C2`.
+    """
+
+    order = "depth-first"
+    direction = "from-inputs"
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(A: R.Tensor, B1: R.Tensor, B2: R.Tensor):
+            with R.dataflow():
+                C2 = R.add(A, B2)
+                C1 = R.add(A, B1)
+                D2 = R.add(A, C2)
+                D1 = R.add(A, C1)
+                E = R.add(D1, D2)
+                R.output(E)
+            return E
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(A: R.Tensor, B1: R.Tensor, B2: R.Tensor):
+            with R.dataflow():
+                C1 = R.add(A, B1)
+                D1 = R.add(A, C1)
+                C2 = R.add(A, B2)
+                D2 = R.add(A, C2)
+                E = R.add(D1, D2)
+                R.output(E)
+            return E
+
+
+class TestDepthFirstBreakTiesByExistingOrder(BaseCompare):
+    """If DFS is ambiguous, provide deterministic output
+
+    Here, both `B1` and `B2` use the input tensor `A`.  Since there
+    are no other inputs for `B1` or `B2`, they remain in the same
+    relative order as the input function, and `B1` is emitted before
+    `B2`.  The DFS then continues, placing `C1` immediately after
+    `B1`.
+    """
+
+    order = "depth-first"
+    direction = "from-inputs"
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(A: R.Tensor):
+            with R.dataflow():
+                B1 = R.add(A, R.const(1))
+                B2 = R.add(A, R.const(2))
+                C2 = R.add(A, B2)
+                C1 = R.add(A, B1)
+                D = R.add(C1, C2)
+                R.output(D)
+            return D
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(A: R.Tensor):
+            with R.dataflow():
+                B1 = R.add(A, R.const(1))
+                C1 = R.add(A, B1)
+                B2 = R.add(A, R.const(2))
+                C2 = R.add(A, B2)
+                D = R.add(C1, C2)
+                R.output(D)
+            return D
+
+
+class TestDepthFirstFromOutput(BaseCompare):
+    """Sort DataflowBlock bindings with DFS, starting from outputs
+
+    Starting with the outputs to the DataflowBlock, sort the variable
+    bindings according to their occurrence in a depth-first search.
+
+    Like `TestDepthFirstFromInputs`, but perform the search starting
+    at the output.
+    """
+
+    order = "depth-first"
+    direction = "from-outputs"
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(A: R.Tensor):
+            with R.dataflow():
+                B2 = R.add(A, R.const(2))
+                B1 = R.add(A, R.const(1))
+                C2 = R.add(A, B2)
+                C1 = R.add(A, B1)
+                D = R.add(C1, C2)
+                R.output(D)
+            return D
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(A: R.Tensor):
+            with R.dataflow():
+                B1 = R.add(A, R.const(1))
+                C1 = R.add(A, B1)
+                B2 = R.add(A, R.const(2))
+                C2 = R.add(A, B2)
+                D = R.add(C1, C2)
+                R.output(D)
+            return D
+
+
+class TestDepthFirstFromOutputTupleWithBinding(BaseCompare):
+    """A dataflow block may produce multiple outputs
+
+    If a dataflow block produces multiple outputs, the result should
+    be sorted according to the order in which the outputs are used.
+    Here, `C1` is used before `C2`, so the expressions required to
+    compute `C1` are moved before the expressions required to compute
+    `C2`.
+    """
+
+    order = "depth-first"
+    direction = "from-outputs"
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(A: R.Tensor):
+            with R.dataflow():
+                B2 = R.add(A, R.const(2))
+                B1 = R.add(A, R.const(1))
+                C2 = R.add(A, B2)
+                C1 = R.add(A, B1)
+                R.output(C1, C2)
+            gv = (C1, C2)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(A: R.Tensor):
+            with R.dataflow():
+                B1 = R.add(A, R.const(1))
+                C1 = R.add(A, B1)
+                B2 = R.add(A, R.const(2))
+                C2 = R.add(A, B2)
+                R.output(C1, C2)
+            gv = (C1, C2)
+            return gv
+
+
+class TestDepthFirstFromOutputTupleWithoutBinding(BaseCompare):
+    """A dataflow block may produce multiple outputs
+
+    Like `TestDepthFirstFromOutputTupleWithBinding`, but the
+    DataflowBlock's outputs are not used as part of a variable
+    binding.  Because in-line tuples are not normalized to variable
+    bindings, this case must be handled explicitly.
+    """
+
+    order = "depth-first"
+    direction = "from-outputs"
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(A: R.Tensor):
+            with R.dataflow():
+                B2 = R.add(A, R.const(2))
+                B1 = R.add(A, R.const(1))
+                C2 = R.add(A, B2)
+                C1 = R.add(A, B1)
+                R.output(C1, C2)
+            return (C1, C2)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(A: R.Tensor):
+            with R.dataflow():
+                B1 = R.add(A, R.const(1))
+                C1 = R.add(A, B1)
+                B2 = R.add(A, R.const(2))
+                C2 = R.add(A, B2)
+                R.output(C1, C2)
+            return (C1, C2)
+
+
+class TestDepthFirstFromOutputWithUnusedVariables(BaseCompare):
+    """Sort DataflowBlock bindings with DFS, starting from outputs
+
+    The variables `D1` and `D2` are unused, but must still appear
+    within the output DataflowBlock.
+
+    This is analogous to `TestDepthFirstFromInputWithConstant`.
+    Similar to how a DFS starting from the function inputs can
+    accidentally skip expressions with no inputs, a DFS starting from
+    the function outputs can accidentally skip expressions that do not
+    contribute to the output.
+    """
+
+    order = "depth-first"
+    direction = "from-outputs"
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(A: R.Tensor):
+            with R.dataflow():
+                B2 = R.add(A, R.const(2))
+                B1 = R.add(A, R.const(1))
+                C2 = R.add(A, B2)
+                C1 = R.add(A, B1)
+                D1 = R.add(A, C1)
+                D2 = R.add(A, C2)
+                E = R.add(C1, C2)
+                R.output(E)
+            return E
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(A: R.Tensor):
+            with R.dataflow():
+                B1 = R.add(A, R.const(1))
+                C1 = R.add(A, B1)
+                D1 = R.add(A, C1)
+                B2 = R.add(A, R.const(2))
+                C2 = R.add(A, B2)
+                D2 = R.add(A, C2)
+                E = R.add(C1, C2)
+                R.output(E)
+            return E
+
+
+class TestDepthFirstFromInputWithUnusedParameters(BaseCompare):
+    """Sort DataflowBlock bindings with DFS, starting from inputs
+
+    Functions may accept parameters that are not used.
+    """
+
+    order = "depth-first"
+    direction = "from-inputs"
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(A: R.Tensor, Unused: R.Tensor):
+            with R.dataflow():
+                B1 = R.add(A, R.const(1))
+                B2 = R.add(A, R.const(2))
+                C1 = R.add(A, B1)
+                C2 = R.add(A, B2)
+                D = R.add(C1, C2)
+                R.output(D)
+            return D
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(A: R.Tensor, Unused: R.Tensor):
+            with R.dataflow():
+                B1 = R.add(A, R.const(1))
+                C1 = R.add(A, B1)
+                B2 = R.add(A, R.const(2))
+                C2 = R.add(A, B2)
+                D = R.add(C1, C2)
+                R.output(D)
+            return D
+
+
+class TestBreadthFirst(BaseCompare):
+    order = "breadth-first"
+    direction = "from-inputs"
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(A: R.Tensor):
+            with R.dataflow():
+                B1 = R.add(A, R.const(1))
+                C1 = R.add(A, B1)
+                B2 = R.add(A, R.const(2))
+                C2 = R.add(A, B2)
+                D = R.add(C1, C2)
+                R.output(D)
+            return D
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(A: R.Tensor):
+            with R.dataflow():
+                B1 = R.add(A, R.const(1))
+                B2 = R.add(A, R.const(2))
+                C1 = R.add(A, B1)
+                C2 = R.add(A, B2)
+                D = R.add(C1, C2)
+                R.output(D)
+            return D
+
+
+class TestBreadthFirstBreakTiesByExistingOrder(BaseCompare):
+    order = "breadth-first"
+    direction = "from-inputs"
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(A: R.Tensor):
+            with R.dataflow():
+                B1 = R.add(A, R.const(2))
+                C1 = R.add(A, B1)
+                B2 = R.add(A, R.const(1))
+                C2 = R.add(A, B2)
+                D = R.add(C2, C1)
+                R.output(D)
+            return D
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(A: R.Tensor):
+            with R.dataflow():
+                B2 = R.add(A, R.const(2))
+                B1 = R.add(A, R.const(1))
+                C2 = R.add(A, B2)
+                C1 = R.add(A, B1)
+                D = R.add(C1, C2)
+                R.output(D)
+            return D
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
