This is an automated email from the ASF dual-hosted git repository.
lmzheng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c5dc98c [AutoSchedule] Support multiple cache read and fix bugs
(#6686)
c5dc98c is described below
commit c5dc98c1c336610eb2763583078312b6f3e07c13
Author: Cody Yu <[email protected]>
AuthorDate: Sun Oct 18 23:56:44 2020 -0700
[AutoSchedule] Support multiple cache read and fix bugs (#6686)
* Add shape to DAG print
* avoid useless cross-thread reduction
* Fix stage order
* support multiple cache_read
* lint
* fix
* fix
* address comment
* fix ci
* Trigger CI & Update doc strings
Co-authored-by: Lianmin Zheng <[email protected]>
---
include/tvm/auto_scheduler/compute_dag.h | 7 +-
python/tvm/auto_scheduler/compute_dag.py | 23 ++++---
src/auto_scheduler/compute_dag.cc | 77 +++++++++++++++++-----
.../search_policy/sketch_policy_rules.cc | 31 +++++----
src/te/schedule/schedule_dataflow_rewrite.cc | 9 +++
.../python/unittest/test_auto_scheduler_common.py | 14 ++++
.../unittest/test_auto_scheduler_compute_dag.py | 53 ++++++++++++++-
7 files changed, 174 insertions(+), 40 deletions(-)
diff --git a/include/tvm/auto_scheduler/compute_dag.h
b/include/tvm/auto_scheduler/compute_dag.h
index 553008a..6e67fef 100755
--- a/include/tvm/auto_scheduler/compute_dag.h
+++ b/include/tvm/auto_scheduler/compute_dag.h
@@ -200,11 +200,16 @@ class ComputeDAGNode : public Object {
*/
class ComputeDAG : public ObjectRef {
public:
- /*! \brief The constructor.
+ /*! \brief Construct a DAG from a list of output tensors.
* \param tensors `te::Tensor`s for a compute declaration.
*/
TVM_DLL explicit ComputeDAG(Array<te::Tensor> tensors);
+ /*! \brief Construct a DAG based on a schedule.
+   * \param sch The `te::Schedule` for a compute declaration.
+ */
+ TVM_DLL explicit ComputeDAG(const te::Schedule& sch);
+
/*!
* \brief Rewrite the layout of placeholder specified by attr
`layout_free_placeholders`
* according to the loop nest derived with `transform_steps`.
diff --git a/python/tvm/auto_scheduler/compute_dag.py
b/python/tvm/auto_scheduler/compute_dag.py
index 0115dbc..4b1b264 100755
--- a/python/tvm/auto_scheduler/compute_dag.py
+++ b/python/tvm/auto_scheduler/compute_dag.py
@@ -47,22 +47,29 @@ class ComputeDAG(Object):
Parameters
----------
- compute : Union[List[Tensor], str]
+ compute : Union[List[Tensor], str, Schedule]
Input/output tensors or workload key for a compute declaration.
"""
- def __init__(self, compute):
- if isinstance(compute, str):
- compute = workload_key_to_tensors(compute)
- elif isinstance(compute, list):
- for item in compute:
+ def __init__(self, compute_or_sche):
+ if isinstance(compute_or_sche, str):
+ compute = workload_key_to_tensors(compute_or_sche)
+ sche = None
+ elif isinstance(compute_or_sche, list):
+ for item in compute_or_sche:
if not isinstance(item, tvm.te.Tensor):
raise ValueError("The input of ComputeDAG should be a list
of Tensor")
+ compute = compute_or_sche
+ sche = None
+ elif isinstance(compute_or_sche, tvm.te.Schedule):
+ compute = None
+ sche = compute_or_sche
else:
raise ValueError(
- "Invalid compute: " + compute + " . ComputeDAG expects a
string or list of Tensor"
+ "Invalid compute type: %s. ComputeDAG expects string, list of
Tensor, or Schedule"
+            % type(compute_or_sche)
)
- self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, compute)
+ self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, compute, sche)
def get_init_state(self):
"""Get the init state of this ComputeDAG.
diff --git a/src/auto_scheduler/compute_dag.cc
b/src/auto_scheduler/compute_dag.cc
index 23b3817..3b0de97 100755
--- a/src/auto_scheduler/compute_dag.cc
+++ b/src/auto_scheduler/compute_dag.cc
@@ -662,7 +662,42 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
auto node = make_object<ComputeDAGNode>();
node->tensors = std::move(tensors);
node->access_analyzer = AccessAnalyzer(node->tensors);
- node->ops = node->access_analyzer->ops_topo_order;
+
+ Array<te::Operation> out_ops;
+ for (const auto& op : node->access_analyzer->ops_topo_order) {
+ if (node->access_analyzer.IsOutput(op)) {
+ out_ops.push_back(op);
+ }
+ }
+ te::Schedule sch = te::create_schedule(out_ops);
+ for (auto stage : sch->stages) {
+ node->ops.push_back(stage->op);
+ }
+
+ node->flop_ct = FlopEstimator().EstimateFlop(node->ops);
+ node->init_state = State(node->ops);
+ data_ = std::move(node);
+}
+
+ComputeDAG::ComputeDAG(const te::Schedule& sch) {
+ auto node = make_object<ComputeDAGNode>();
+
+ // Initialize ops. Here we enforce the order of ops and stages are consistent
+ for (auto stage : sch->stages) {
+ node->ops.push_back(stage->op);
+ }
+
+ // Collect input and output tensors
+ Array<te::Tensor> tensors;
+ for (auto stage : sch->stages) {
+ if (stage->op->IsInstance<te::PlaceholderOpNode>() || stage->is_output) {
+ for (auto i = 0; i < stage->op->num_outputs(); ++i) {
+ tensors.push_back(stage->op.output(i));
+ }
+ }
+ }
+ node->tensors = std::move(tensors);
+ node->access_analyzer = AccessAnalyzer(node->tensors);
node->flop_ct = FlopEstimator().EstimateFlop(node->ops);
node->init_state = State(node->ops);
data_ = std::move(node);
@@ -949,8 +984,6 @@ void ComputeDAG::RewriteLayout(const Array<Step>&
transform_steps) {
}
}
- p_dag->init_state = State(p_dag->ops);
-
Array<te::Tensor> old_tensors = p_dag->tensors;
ArrayNode* p_tensors = p_dag->tensors.CopyOnWrite();
@@ -970,8 +1003,21 @@ void ComputeDAG::RewriteLayout(const Array<Step>&
transform_steps) {
} // end for placeholder
} // end for stage
p_dag->access_analyzer = AccessAnalyzer(p_dag->tensors);
- p_dag->ops = p_dag->access_analyzer->ops_topo_order;
+
+ Array<te::Operation> out_ops;
+ for (const auto& op : p_dag->access_analyzer->ops_topo_order) {
+ if (p_dag->access_analyzer.IsOutput(op)) {
+ out_ops.push_back(op);
+ }
+ }
+
+ p_dag->ops.clear();
+ te::Schedule sch = te::create_schedule(out_ops);
+ for (auto stage : sch->stages) {
+ p_dag->ops.push_back(stage->op);
+ }
p_dag->flop_ct = FlopEstimator().EstimateFlop(p_dag->ops);
+ p_dag->init_state = State(p_dag->ops);
}
std::pair<te::Schedule, Array<te::Tensor>> ComputeDAG::ApplySteps(
@@ -1144,17 +1190,7 @@ ComputeDAG ComputeDAG::ReplayAndGetDAG(const
Array<Step>& transform_steps) const
te::Schedule sch;
Array<te::Tensor> old_tensors;
std::tie(sch, old_tensors) = ApplySteps(transform_steps);
-
- Array<te::Tensor> new_tensors;
- for (auto stage : sch->stages) {
- if (stage->op->IsInstance<te::PlaceholderOpNode>() || stage->is_output) {
- for (auto i = 0; i < stage->op->num_outputs(); ++i) {
- new_tensors.push_back(stage->op.output(i));
- }
- }
- }
-
- return ComputeDAG(new_tensors);
+ return ComputeDAG(sch);
}
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
@@ -1259,9 +1295,14 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << ss.str();
});
-TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAG").set_body_typed([](Array<te::Tensor>
tensors) {
- return ComputeDAG(tensors);
-});
+TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAG")
+ .set_body_typed([](Optional<Array<te::Tensor>> tensors,
Optional<te::Schedule> sch) {
+ if (tensors) {
+ return ComputeDAG(tensors.value());
+ }
+ CHECK(sch) << "Both tensors and schedule are null";
+ return ComputeDAG(sch.value());
+ });
TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAGApplyStepsFromState")
.set_body_typed([](const ComputeDAG& dag, const State& state, const bool
layout_rewrite) {
diff --git a/src/auto_scheduler/search_policy/sketch_policy_rules.cc
b/src/auto_scheduler/search_policy/sketch_policy_rules.cc
index b6ad4d3..1b965c9 100644
--- a/src/auto_scheduler/search_policy/sketch_policy_rules.cc
+++ b/src/auto_scheduler/search_policy/sketch_policy_rules.cc
@@ -151,9 +151,6 @@ SketchGenerationRule::ConditionKind
RuleAddCacheRead::MeetCondition(const Sketch
// Don't cache_read a stage if it has multiple consumers
const std::set<int>& consumers = GetConsumers(task, state, stage_id);
- if (consumers.size() != 1) {
- return ConditionKind::kSkip;
- }
// Don't cache_read a stage if its consumer does not need multi-level tiling
int target_stage_id = *consumers.begin();
@@ -179,16 +176,22 @@ std::vector<std::pair<State, int>>
RuleAddCacheRead::Apply(const SketchPolicyNod
const State& state,
int stage_id) const {
const SearchTask& task = policy.search_task;
const std::set<int>& consumers = GetConsumers(task, state, stage_id);
- CHECK_EQ(consumers.size(), 1);
- int target_stage_id = *consumers.begin();
State tmp_s = state;
- // Cache read add shared memory
- int added_stage_id = tmp_s.cache_read(stage_id, "shared", {target_stage_id},
task->compute_dag);
- target_stage_id++;
- const auto& share_read_pos =
-
GetLastReduceIteratorInOutermostReduceTile(tmp_s->stages[target_stage_id]);
- tmp_s.compute_at(added_stage_id, target_stage_id, share_read_pos);
+ int target_stage_id_offset = 0;
+ for (int orig_target_stage_id : consumers) {
+ int target_stage_id = orig_target_stage_id + target_stage_id_offset;
+
+ // Cache read add shared memory
+ int added_stage_id = tmp_s.cache_read(stage_id, "shared",
{target_stage_id}, task->compute_dag);
+ target_stage_id_offset++;
+ target_stage_id++;
+
+ const auto& share_read_pos =
+
GetLastReduceIteratorInOutermostReduceTile(tmp_s->stages[target_stage_id]);
+ tmp_s.compute_at(added_stage_id, target_stage_id, share_read_pos);
+ }
+
return {std::make_pair(tmp_s, stage_id)};
}
@@ -332,7 +335,11 @@ SketchGenerationRule::ConditionKind
RuleCrossThreadReduction::MeetCondition(
GetCumulativeSpaceAndReductionLength(state->stages[stage_id]);
if (NeedsMultilevelTiling(policy.search_task, state, stage_id)) {
- // Do rfactor if we do not have enough parallelism on space iters
+ // Avoid rfactor if we have enough parallelism on space iters
+ if (cum_space_len >
policy.search_task->hardware_params->max_threads_per_block) {
+ return ConditionKind::kSkip;
+ }
+
return cum_space_len < cum_reduce_len ? ConditionKind::kApply :
ConditionKind::kSkip;
} else if (cum_reduce_len > 1) {
// Try rfactor for other reduction operators
diff --git a/src/te/schedule/schedule_dataflow_rewrite.cc
b/src/te/schedule/schedule_dataflow_rewrite.cc
index f335f95..941817a 100644
--- a/src/te/schedule/schedule_dataflow_rewrite.cc
+++ b/src/te/schedule/schedule_dataflow_rewrite.cc
@@ -136,6 +136,15 @@ Tensor Schedule::cache_read(const Tensor& tensor, const
std::string& scope,
if (tensor->op->num_outputs() != 1) {
os << ".v" << tensor->value_index;
}
+
+ // when a schedule has multiple cache_read on the same tensor,
+ // we make sure their op names are unique. e.g., w.shared, w_d.shared,
w_d_d.shared
+ for (auto pair : (*this)->stage_map) {
+ auto stage = pair.second;
+ if (stage->op->name == os.str() + "." + scope) {
+ os << ".d";
+ }
+ }
os << "." << scope;
std::unordered_map<Tensor, Tensor> vsub;
diff --git a/tests/python/unittest/test_auto_scheduler_common.py
b/tests/python/unittest/test_auto_scheduler_common.py
index 880b112..764099e 100644
--- a/tests/python/unittest/test_auto_scheduler_common.py
+++ b/tests/python/unittest/test_auto_scheduler_common.py
@@ -53,6 +53,20 @@ def double_matmul_auto_scheduler_test(N):
return [A, B, C, E]
+@auto_scheduler.register_workload
+def parallel_matmul_auto_scheduler_test(N):
+ """Two parallel matmuls with shared A."""
+ A = te.placeholder((N, N), name="A", dtype="float32")
+ B = te.placeholder((N, N), name="B", dtype="float32")
+ C = te.placeholder((N, N), name="C", dtype="float32")
+ k = te.reduce_axis((0, N), name="k")
+ D = te.compute((N, N), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]),
name="D")
+ k = te.reduce_axis((0, N), name="k")
+ E = te.compute((N, N), lambda i, j: te.sum(A[i][k] * C[k][j], axis=[k]),
name="E")
+
+ return [A, B, C, D, E]
+
+
# Test for register_workload with different name
@auto_scheduler.register_workload("matmul_auto_scheduler_test_rename_1")
def matmul_auto_scheduler_test_rename_0(N, M, K):
diff --git a/tests/python/unittest/test_auto_scheduler_compute_dag.py
b/tests/python/unittest/test_auto_scheduler_compute_dag.py
index a58f2ca..2ccedef 100644
--- a/tests/python/unittest/test_auto_scheduler_compute_dag.py
+++ b/tests/python/unittest/test_auto_scheduler_compute_dag.py
@@ -21,7 +21,11 @@ import tvm
from tvm import topi
from tvm import auto_scheduler, te
-from test_auto_scheduler_common import get_tiled_matmul,
matmul_auto_scheduler_test
+from test_auto_scheduler_common import (
+ get_tiled_matmul,
+ matmul_auto_scheduler_test,
+ parallel_matmul_auto_scheduler_test,
+)
def test_apply_steps():
@@ -56,7 +60,54 @@ def test_estimate_flop():
assert abs(dag.flop_ct - (2 * N ** 3 + 1234)) < 0.5
+def test_stage_order():
+ N = 512
+ A, B, C, D, E = parallel_matmul_auto_scheduler_test(N)
+ sch = te.create_schedule([D.op, E.op])
+ (D_local,) = sch.cache_write([D], "local")
+ (E_local,) = sch.cache_write([E], "local")
+ sch.cache_read(A, "shared", [D_local])
+ sch.cache_read(B, "shared", [D_local])
+ sch.cache_read(A, "shared", [E_local])
+ sch.cache_read(C, "shared", [E_local])
+
+ dag = auto_scheduler.ComputeDAG(sch)
+ stage_ops_1 = dag.get_init_state().stage_ops
+
+ # 3 placeholder, 4 x.shared, 2 {D,E}.local, 2 {D,E} compute
+ assert len(stage_ops_1) == 11
+
+ # Cache read stage should follow the source stage
+ for idx, op in enumerate(stage_ops_1):
+ if op.name == "A":
+ assert (
+ stage_ops_1[idx + 1].name == "A.d.shared"
+ and stage_ops_1[idx + 2].name == "A.shared"
+ )
+ elif op.name in ["B", "C"]:
+ assert stage_ops_1[idx + 1].name == "%s.shared" % op.name
+
+ # Apply the same schedule to Ansor state and it should have the same stage
order
+ dag = auto_scheduler.ComputeDAG([A, B, C, D, E])
+ state = dag.get_init_state()
+
+ D_local = state.cache_write(D, "local")
+ E_local = state.cache_write(E, "local")
+ state.cache_read(A, "shared", [D_local])
+ state.cache_read(B, "shared", [D_local])
+ state.cache_read(A, "shared", [E_local])
+ state.cache_read(C, "shared", [E_local])
+
+ stage_ops_2 = state.stage_ops
+ assert len(stage_ops_1) == len(stage_ops_2)
+
+ # Cache read stage should follow the source stage
+ for op1, op2 in zip(stage_ops_1, stage_ops_2):
+ assert op1.name == op2.name
+
+
if __name__ == "__main__":
test_apply_steps()
test_infer_bound()
test_estimate_flop()
+ test_stage_order()