This is an automated email from the ASF dual-hosted git repository.

lmzheng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c5dc98c  [AutoSchedule] Support multiple cache read and fix bugs (#6686)
c5dc98c is described below

commit c5dc98c1c336610eb2763583078312b6f3e07c13
Author: Cody Yu <[email protected]>
AuthorDate: Sun Oct 18 23:56:44 2020 -0700

    [AutoSchedule] Support multiple cache read and fix bugs (#6686)
    
    * Add shape to DAG print
    
    * avoid useless cross-thread reduction
    
    * Fix stage order
    
    * support multiple cache_read
    
    * lint
    
    * fix
    
    * fix
    
    * address comment
    
    * fix ci
    
    * Trigger CI & Update doc strings
    
    Co-authored-by: Lianmin Zheng <[email protected]>
---
 include/tvm/auto_scheduler/compute_dag.h           |  7 +-
 python/tvm/auto_scheduler/compute_dag.py           | 23 ++++---
 src/auto_scheduler/compute_dag.cc                  | 77 +++++++++++++++++-----
 .../search_policy/sketch_policy_rules.cc           | 31 +++++----
 src/te/schedule/schedule_dataflow_rewrite.cc       |  9 +++
 .../python/unittest/test_auto_scheduler_common.py  | 14 ++++
 .../unittest/test_auto_scheduler_compute_dag.py    | 53 ++++++++++++++-
 7 files changed, 174 insertions(+), 40 deletions(-)

diff --git a/include/tvm/auto_scheduler/compute_dag.h 
b/include/tvm/auto_scheduler/compute_dag.h
index 553008a..6e67fef 100755
--- a/include/tvm/auto_scheduler/compute_dag.h
+++ b/include/tvm/auto_scheduler/compute_dag.h
@@ -200,11 +200,16 @@ class ComputeDAGNode : public Object {
  */
 class ComputeDAG : public ObjectRef {
  public:
-  /*! \brief The constructor.
+  /*! \brief Construct a DAG from a list of output tensors.
    * \param tensors `te::Tensor`s for a compute declaration.
    */
   TVM_DLL explicit ComputeDAG(Array<te::Tensor> tensors);
 
+  /*! \brief Construct a DAG based on a schedule.
+   * \param sch The `te::Schedule` of a compute declaration; DAG ops keep its stage order.
+   */
+  TVM_DLL explicit ComputeDAG(const te::Schedule& sch);
+
   /*!
    * \brief Rewrite the layout of placeholder specified by attr 
`layout_free_placeholders`
    * according to the loop nest derived with `transform_steps`.
diff --git a/python/tvm/auto_scheduler/compute_dag.py 
b/python/tvm/auto_scheduler/compute_dag.py
index 0115dbc..4b1b264 100755
--- a/python/tvm/auto_scheduler/compute_dag.py
+++ b/python/tvm/auto_scheduler/compute_dag.py
@@ -47,22 +47,29 @@ class ComputeDAG(Object):

     Parameters
     ----------
-    compute : Union[List[Tensor], str]
-        Input/output tensors or workload key for a compute declaration.
+    compute_or_sche : Union[List[Tensor], str, Schedule]
+        Input/output tensors, a workload key, or a schedule for a compute declaration.
     """

-    def __init__(self, compute):
-        if isinstance(compute, str):
-            compute = workload_key_to_tensors(compute)
-        elif isinstance(compute, list):
-            for item in compute:
+    def __init__(self, compute_or_sche):
+        if isinstance(compute_or_sche, str):
+            compute = workload_key_to_tensors(compute_or_sche)
+            sche = None
+        elif isinstance(compute_or_sche, list):
+            for item in compute_or_sche:
                 if not isinstance(item, tvm.te.Tensor):
                     raise ValueError("The input of ComputeDAG should be a list of Tensor")
+            compute = compute_or_sche
+            sche = None
+        elif isinstance(compute_or_sche, tvm.te.Schedule):
+            compute = None
+            sche = compute_or_sche
         else:
             raise ValueError(
-                "Invalid compute: " + compute + " . ComputeDAG expects a string or list of Tensor"
+                "Invalid compute type: %s. ComputeDAG expects string, list of Tensor, or Schedule"
+                % type(compute_or_sche)
             )
-        self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, compute)
+        self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, compute, sche)
 
     def get_init_state(self):
         """Get the init state of this ComputeDAG.
diff --git a/src/auto_scheduler/compute_dag.cc 
b/src/auto_scheduler/compute_dag.cc
index 23b3817..3b0de97 100755
--- a/src/auto_scheduler/compute_dag.cc
+++ b/src/auto_scheduler/compute_dag.cc
@@ -662,7 +662,42 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
   auto node = make_object<ComputeDAGNode>();
   node->tensors = std::move(tensors);
   node->access_analyzer = AccessAnalyzer(node->tensors);
-  node->ops = node->access_analyzer->ops_topo_order;
+
+  Array<te::Operation> out_ops;
+  for (const auto& op : node->access_analyzer->ops_topo_order) {
+    if (node->access_analyzer.IsOutput(op)) {
+      out_ops.push_back(op);
+    }
+  }
+  te::Schedule sch = te::create_schedule(out_ops);
+  for (auto stage : sch->stages) {
+    node->ops.push_back(stage->op);
+  }
+
+  node->flop_ct = FlopEstimator().EstimateFlop(node->ops);
+  node->init_state = State(node->ops);
+  data_ = std::move(node);
+}
+
+ComputeDAG::ComputeDAG(const te::Schedule& sch) {
+  auto node = make_object<ComputeDAGNode>();
+
+  // Initialize ops. Here we enforce the order of ops and stages are consistent
+  for (auto stage : sch->stages) {
+    node->ops.push_back(stage->op);
+  }
+
+  // Collect input and output tensors
+  Array<te::Tensor> tensors;
+  for (auto stage : sch->stages) {
+    if (stage->op->IsInstance<te::PlaceholderOpNode>() || stage->is_output) {
+      for (auto i = 0; i < stage->op->num_outputs(); ++i) {
+        tensors.push_back(stage->op.output(i));
+      }
+    }
+  }
+  node->tensors = std::move(tensors);
+  node->access_analyzer = AccessAnalyzer(node->tensors);
   node->flop_ct = FlopEstimator().EstimateFlop(node->ops);
   node->init_state = State(node->ops);
   data_ = std::move(node);
@@ -949,8 +984,6 @@ void ComputeDAG::RewriteLayout(const Array<Step>& 
transform_steps) {
         }
       }
 
-      p_dag->init_state = State(p_dag->ops);
-
       Array<te::Tensor> old_tensors = p_dag->tensors;
       ArrayNode* p_tensors = p_dag->tensors.CopyOnWrite();
 
@@ -970,8 +1003,21 @@ void ComputeDAG::RewriteLayout(const Array<Step>& 
transform_steps) {
     }  // end for placeholder
   }    // end for stage
   p_dag->access_analyzer = AccessAnalyzer(p_dag->tensors);
-  p_dag->ops = p_dag->access_analyzer->ops_topo_order;
+
+  Array<te::Operation> out_ops;
+  for (const auto& op : p_dag->access_analyzer->ops_topo_order) {
+    if (p_dag->access_analyzer.IsOutput(op)) {
+      out_ops.push_back(op);
+    }
+  }
+
+  p_dag->ops.clear();
+  te::Schedule sch = te::create_schedule(out_ops);
+  for (auto stage : sch->stages) {
+    p_dag->ops.push_back(stage->op);
+  }
   p_dag->flop_ct = FlopEstimator().EstimateFlop(p_dag->ops);
+  p_dag->init_state = State(p_dag->ops);
 }
 
 std::pair<te::Schedule, Array<te::Tensor>> ComputeDAG::ApplySteps(
@@ -1144,17 +1190,7 @@ ComputeDAG ComputeDAG::ReplayAndGetDAG(const 
Array<Step>& transform_steps) const
   te::Schedule sch;
   Array<te::Tensor> old_tensors;
   std::tie(sch, old_tensors) = ApplySteps(transform_steps);
-
-  Array<te::Tensor> new_tensors;
-  for (auto stage : sch->stages) {
-    if (stage->op->IsInstance<te::PlaceholderOpNode>() || stage->is_output) {
-      for (auto i = 0; i < stage->op->num_outputs(); ++i) {
-        new_tensors.push_back(stage->op.output(i));
-      }
-    }
-  }
-
-  return ComputeDAG(new_tensors);
+  return ComputeDAG(sch);
 }
 
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
@@ -1259,9 +1295,14 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
       p->stream << ss.str();
     });
 
-TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAG").set_body_typed([](Array<te::Tensor>
 tensors) {
-  return ComputeDAG(tensors);
-});
+TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAG")
+    .set_body_typed([](Optional<Array<te::Tensor>> tensors, 
Optional<te::Schedule> sch) {
+      if (tensors) {
+        return ComputeDAG(tensors.value());
+      }
+      CHECK(sch) << "Both tensors and schedule are null";
+      return ComputeDAG(sch.value());
+    });
 
 TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAGApplyStepsFromState")
     .set_body_typed([](const ComputeDAG& dag, const State& state, const bool 
layout_rewrite) {
diff --git a/src/auto_scheduler/search_policy/sketch_policy_rules.cc 
b/src/auto_scheduler/search_policy/sketch_policy_rules.cc
index b6ad4d3..1b965c9 100644
--- a/src/auto_scheduler/search_policy/sketch_policy_rules.cc
+++ b/src/auto_scheduler/search_policy/sketch_policy_rules.cc
@@ -151,9 +151,6 @@ SketchGenerationRule::ConditionKind RuleAddCacheRead::MeetCondition(const Sketch
 
-  // Don't cache_read a stage if it has multiple consumers
   const std::set<int>& consumers = GetConsumers(task, state, stage_id);
-  if (consumers.size() != 1) {
-    return ConditionKind::kSkip;
-  }
+  if (consumers.empty()) return ConditionKind::kSkip;  // Need a consumer to cache for
 
   // Don't cache_read a stage if its consumer does not need multi-level tiling
   int target_stage_id = *consumers.begin();
@@ -179,16 +176,22 @@ std::vector<std::pair<State, int>> 
RuleAddCacheRead::Apply(const SketchPolicyNod
                                                            const State& state, 
int stage_id) const {
   const SearchTask& task = policy.search_task;
   const std::set<int>& consumers = GetConsumers(task, state, stage_id);
-  CHECK_EQ(consumers.size(), 1);
-  int target_stage_id = *consumers.begin();
   State tmp_s = state;
 
-  // Cache read add shared memory
-  int added_stage_id = tmp_s.cache_read(stage_id, "shared", {target_stage_id}, 
task->compute_dag);
-  target_stage_id++;
-  const auto& share_read_pos =
-      
GetLastReduceIteratorInOutermostReduceTile(tmp_s->stages[target_stage_id]);
-  tmp_s.compute_at(added_stage_id, target_stage_id, share_read_pos);
+  int target_stage_id_offset = 0;
+  for (int orig_target_stage_id : consumers) {
+    int target_stage_id = orig_target_stage_id + target_stage_id_offset;
+
+    // Cache read add shared memory
+    int added_stage_id = tmp_s.cache_read(stage_id, "shared", 
{target_stage_id}, task->compute_dag);
+    target_stage_id_offset++;
+    target_stage_id++;
+
+    const auto& share_read_pos =
+        
GetLastReduceIteratorInOutermostReduceTile(tmp_s->stages[target_stage_id]);
+    tmp_s.compute_at(added_stage_id, target_stage_id, share_read_pos);
+  }
+
   return {std::make_pair(tmp_s, stage_id)};
 }
 
@@ -332,7 +335,11 @@ SketchGenerationRule::ConditionKind 
RuleCrossThreadReduction::MeetCondition(
         GetCumulativeSpaceAndReductionLength(state->stages[stage_id]);
 
     if (NeedsMultilevelTiling(policy.search_task, state, stage_id)) {
-      // Do rfactor if we do not have enough parallelism on space iters
+      // Avoid rfactor if we have enough parallelism on space iters
+      if (cum_space_len > 
policy.search_task->hardware_params->max_threads_per_block) {
+        return ConditionKind::kSkip;
+      }
+
       return cum_space_len < cum_reduce_len ? ConditionKind::kApply : 
ConditionKind::kSkip;
     } else if (cum_reduce_len > 1) {
       // Try rfactor for other reduction operators
diff --git a/src/te/schedule/schedule_dataflow_rewrite.cc 
b/src/te/schedule/schedule_dataflow_rewrite.cc
index f335f95..941817a 100644
--- a/src/te/schedule/schedule_dataflow_rewrite.cc
+++ b/src/te/schedule/schedule_dataflow_rewrite.cc
@@ -136,6 +136,21 @@ Tensor Schedule::cache_read(const Tensor& tensor, const std::string& scope,
   if (tensor->op->num_outputs() != 1) {
     os << ".v" << tensor->value_index;
   }
+
+  // When a schedule has multiple cache_read stages on the same tensor, make sure
+  // their op names are unique, e.g., w.shared, w.d.shared, w.d.d.shared.
+  // Keep appending ".d" until no existing stage uses the candidate name. A single
+  // pass is not enough: whether "w.d.shared" is visited before "w.shared" depends
+  // on the iteration order of stage_map, which is not guaranteed.
+  auto name_in_use = [&](const std::string& name) {
+    for (const auto& kv : (*this)->stage_map) {
+      if (kv.second->op->name == name) return true;
+    }
+    return false;
+  };
+  while (name_in_use(os.str() + "." + scope)) {
+    os << ".d";
+  }
   os << "." << scope;
 
   std::unordered_map<Tensor, Tensor> vsub;
diff --git a/tests/python/unittest/test_auto_scheduler_common.py 
b/tests/python/unittest/test_auto_scheduler_common.py
index 880b112..764099e 100644
--- a/tests/python/unittest/test_auto_scheduler_common.py
+++ b/tests/python/unittest/test_auto_scheduler_common.py
@@ -53,6 +53,20 @@ def double_matmul_auto_scheduler_test(N):
     return [A, B, C, E]
 
 
+@auto_scheduler.register_workload
+def parallel_matmul_auto_scheduler_test(N):
+    """Two parallel matmuls with shared A."""
+    A = te.placeholder((N, N), name="A", dtype="float32")
+    B = te.placeholder((N, N), name="B", dtype="float32")
+    C = te.placeholder((N, N), name="C", dtype="float32")
+    k = te.reduce_axis((0, N), name="k")
+    D = te.compute((N, N), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), 
name="D")
+    k = te.reduce_axis((0, N), name="k")
+    E = te.compute((N, N), lambda i, j: te.sum(A[i][k] * C[k][j], axis=[k]), 
name="E")
+
+    return [A, B, C, D, E]
+
+
 # Test for register_workload with different name
 @auto_scheduler.register_workload("matmul_auto_scheduler_test_rename_1")
 def matmul_auto_scheduler_test_rename_0(N, M, K):
diff --git a/tests/python/unittest/test_auto_scheduler_compute_dag.py 
b/tests/python/unittest/test_auto_scheduler_compute_dag.py
index a58f2ca..2ccedef 100644
--- a/tests/python/unittest/test_auto_scheduler_compute_dag.py
+++ b/tests/python/unittest/test_auto_scheduler_compute_dag.py
@@ -21,7 +21,11 @@ import tvm
 from tvm import topi
 from tvm import auto_scheduler, te
 
-from test_auto_scheduler_common import get_tiled_matmul, 
matmul_auto_scheduler_test
+from test_auto_scheduler_common import (
+    get_tiled_matmul,
+    matmul_auto_scheduler_test,
+    parallel_matmul_auto_scheduler_test,
+)
 
 
 def test_apply_steps():
@@ -56,7 +60,54 @@ def test_estimate_flop():
     assert abs(dag.flop_ct - (2 * N ** 3 + 1234)) < 0.5
 
 
+def test_stage_order():
+    N = 512
+    A, B, C, D, E = parallel_matmul_auto_scheduler_test(N)
+    sch = te.create_schedule([D.op, E.op])
+    (D_local,) = sch.cache_write([D], "local")
+    (E_local,) = sch.cache_write([E], "local")
+    sch.cache_read(A, "shared", [D_local])
+    sch.cache_read(B, "shared", [D_local])
+    sch.cache_read(A, "shared", [E_local])
+    sch.cache_read(C, "shared", [E_local])
+
+    dag = auto_scheduler.ComputeDAG(sch)
+    stage_ops_1 = dag.get_init_state().stage_ops
+
+    # 3 placeholders, 4 *.shared, 2 {D,E}.local, 2 {D,E} compute
+    assert len(stage_ops_1) == 11
+
+    # Each cache_read stage directly follows its source; A's second copy is "A.d.shared"
+    for idx, op in enumerate(stage_ops_1):
+        if op.name == "A":
+            assert (
+                stage_ops_1[idx + 1].name == "A.d.shared"
+                and stage_ops_1[idx + 2].name == "A.shared"
+            )
+        elif op.name in ["B", "C"]:
+            assert stage_ops_1[idx + 1].name == "%s.shared" % op.name
+
+    # Apply the same schedule to Ansor state and it should have the same stage order
+    dag = auto_scheduler.ComputeDAG([A, B, C, D, E])
+    state = dag.get_init_state()
+
+    D_local = state.cache_write(D, "local")
+    E_local = state.cache_write(E, "local")
+    state.cache_read(A, "shared", [D_local])
+    state.cache_read(B, "shared", [D_local])
+    state.cache_read(A, "shared", [E_local])
+    state.cache_read(C, "shared", [E_local])
+
+    stage_ops_2 = state.stage_ops
+    assert len(stage_ops_1) == len(stage_ops_2)
+
+    # Both construction paths must yield identical stage names in identical order
+    for op1, op2 in zip(stage_ops_1, stage_ops_2):
+        assert op1.name == op2.name
+
 if __name__ == "__main__":
     test_apply_steps()
     test_infer_bound()
     test_estimate_flop()
+    test_stage_order()

Reply via email to