This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ce1fa8908f [TE] Record primitives of Schedule for visualization (#14168)
ce1fa8908f is described below
commit ce1fa8908f626e58f245966dd0a2e2540b75dace
Author: Chun-I Tsai <[email protected]>
AuthorDate: Wed Mar 15 04:52:11 2023 +0800
[TE] Record primitives of Schedule for visualization (#14168)
* [ScheduleVisualization]
- Make Stage link to its Schedule via attach_sch.
- Add two array attributes, primitive_record and schedule_record to
Schedule.
- Create a new class, ScheduleContext, to record primitives.
- Register a pass config variable, keep_schedule_record to enable/disable
the recording.
- Add test cases for TEDD, build_module and schedule ops.
* [ScheduleVisualization]
* Fix grammar issues
* Rewrite unclear comments
* [ScheduleVisualization]
* Remove the incorrect term "rebased" from comments and variables.
---------
Co-authored-by: Joey Tsai <[email protected]>
---
include/tvm/te/schedule.h | 43 +++++++++++++++++++-
python/tvm/contrib/tedd.py | 27 ++++++++++++-
src/relay/backend/te_compiler.cc | 13 +++++-
src/te/schedule/schedule_dataflow_rewrite.cc | 14 ++++---
src/te/schedule/schedule_lang.cc | 51 ++++++++++++++++++++++-
tests/python/contrib/test_tedd.py | 58 ++++++++++++++++++++++++++-
tests/python/relay/test_build_module.py | 37 +++++++++++++++++
tests/python/unittest/test_te_schedule_ops.py | 53 ++++++++++++++++++++++++
8 files changed, 283 insertions(+), 13 deletions(-)
diff --git a/include/tvm/te/schedule.h b/include/tvm/te/schedule.h
index 5d88793206..1b711a8370 100644
--- a/include/tvm/te/schedule.h
+++ b/include/tvm/te/schedule.h
@@ -62,8 +62,9 @@ class Stage : public ObjectRef {
/*!
* \brief create a new schedule for op.
* \param op The operator in the schedule
+ * \param sch The schedule that this stage belongs to
*/
- explicit Stage(Operation op);
+ explicit Stage(Operation op, const ScheduleNode* sch);
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
@@ -445,6 +446,26 @@ class Schedule : public ObjectRef {
using ContainerType = ScheduleNode;
};
+/*!
+ * \brief Context helper that collects debug information about a Schedule.
+ *
+ * Place With<ScheduleContext>(schedule_node, primitive_name) inside the
+ * function body of a schedule primitive to record a snapshot of the
+ * schedule together with the name of the primitive that was applied.
+ */
+class ScheduleContext {
+ private:
+ friend class With<ScheduleContext>;
+ ScheduleContext(const ScheduleNode* sch_node, String current_primitive_name);
+ void EnterWithScope();
+ void ExitWithScope();
+
+ /*! \brief Schedule instance to store information for debug */
+ Schedule sch_;
+ /*! \brief String representing which primitive has been applied to sch_ */
+ String current_primitive_name_;
+};
+
/*!
* \brief The schedule relation between IterVars
* can be Split, Fuse.
@@ -546,6 +567,8 @@ class StageNode : public Object {
IterVar attach_ivar;
/*! \brief The stage this node attaches to */
Stage attach_stage;
+ /*! \brief The schedule that this stage is attached to */
+ const ScheduleNode* attach_sch;
/*! \brief The thread storage scope level of the stage */
std::string scope;
/*! \brief Whether this is an output stage */
@@ -615,12 +638,30 @@ class ScheduleNode : public Object {
* This is created on demand and can be invalidated.
*/
std::unordered_map<const Object*, Stage> op2stage_cache_;
+ /*!
+ * \brief List of recorded schedule snapshots.
+ * Users can step through the optimization strategy with TEDD to check the
+ * order and effect of primitives. Set "te.keep_schedule_record" to true in
+ * the PassContext config to enable recording.
+ */
+ Array<Schedule> schedule_record;
+ /*!
+ * \brief list of all applied primitive names.
+ */
+ Array<String> primitive_record;
+ /*!
+ * \brief Flag to keep schedule record or not.
+ */
+ Optional<Bool> keep_schedule_record;
void VisitAttrs(AttrVisitor* v) {
v->Visit("outputs", &outputs);
v->Visit("stages", &stages);
v->Visit("groups", &groups);
v->Visit("stage_map", &stage_map);
+ v->Visit("schedule_record", &schedule_record);
+ v->Visit("primitive_record", &primitive_record);
+ v->Visit("keep_schedule_record", &keep_schedule_record);
}
/*! \brief Initialize temp cache. */
diff --git a/python/tvm/contrib/tedd.py b/python/tvm/contrib/tedd.py
index a65f5e474a..aa423d8964 100644
--- a/python/tvm/contrib/tedd.py
+++ b/python/tvm/contrib/tedd.py
@@ -78,6 +78,27 @@ def insert_dot_id(sch):
return sch
+def itervar_equal(iv_a, iv_b):
+ """A helper method that compares the equality of two iterative variables"""
+ # Compare the itervars field by field to decide whether they are equal.
+ # A plain comparison (i.e. iv_a == iv_b) might fail after InferBound
+ # changes the domain bounds.
+ def _var_equal(v_a, v_b):
+ condtions = [
+ v_a.name == v_b.name,
+ v_a.dtype == v_b.dtype,
+ v_a.type_annotation == v_b.type_annotation,
+ ]
+ return all(c for c in condtions)
+
+ condtions = [
+ _var_equal(iv_a.var, iv_b.var),
+ iv_a.iter_type == iv_b.iter_type,
+ iv_a.thread_tag == iv_b.thread_tag,
+ ]
+ return all(c for c in condtions)
+
+
class ObjectManager:
"""A helper class tracking schedule objects, e.g. stage, IterVar,
relationship, and tensor, to their DOM path."""
@@ -88,6 +109,10 @@ class ObjectManager:
self.dict[stage] = [stage_idx]
for itervar_idx, itervar in enumerate(stage.all_iter_vars):
self.dict[itervar] = [stage_idx, itervar_idx]
+ # the leaf itervars should also be mapped to the original one
+ for leaf_iv in stage.leaf_iter_vars:
+ if itervar_equal(leaf_iv, itervar):
+ self.dict[leaf_iv] = [stage_idx, itervar_idx]
for rel_idx, rel in enumerate(stage.relations):
self.dict[rel] = [stage_idx, rel_idx]
for tensor_idx in range(stage.op.num_outputs):
@@ -289,7 +314,7 @@ def dump_json(sch, need_range):
def get_leaf_itervar_index(itervar, leaf_iv):
for leaf_index, ivar in enumerate(leaf_iv):
- if ivar == itervar:
+ if itervar_equal(ivar, itervar):
return leaf_index
return -1
diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc
index e20e0c9429..ce47be361e 100644
--- a/src/relay/backend/te_compiler.cc
+++ b/src/relay/backend/te_compiler.cc
@@ -700,7 +700,8 @@ class LowerTensorExprMutator : public
DeviceAwareExprMutator {
*/
Expr MakeLoweredCall(const BaseFunc& original_function, const GlobalVar&
prim_fn_var,
Array<Expr> args, Span span, const Target& target,
- const Map<GlobalVar, BaseFunc>& lowered_functions) {
+ const Map<GlobalVar, BaseFunc>& lowered_functions,
+ const te::Schedule& sch = {}) {
auto opt_compiler = original_function->GetAttr<String>(attr::kCompiler);
// Add some metadata on top of the *original function* and invoke the
callback so it can
@@ -730,6 +731,10 @@ class LowerTensorExprMutator : public
DeviceAwareExprMutator {
func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var",
prim_fn_var);
func_with_metadata = WithAttr(func_with_metadata, "prim_funcs",
prim_fns);
func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget,
target);
+ // Store generated Schedules of operator
+ if (sch.defined() && sch->keep_schedule_record) {
+ func_with_metadata = WithAttr(func_with_metadata, "schedule", sch);
+ }
this->process_fn_(func_with_metadata);
} else {
const auto* function_node = original_function.as<FunctionNode>();
@@ -738,6 +743,10 @@ class LowerTensorExprMutator : public
DeviceAwareExprMutator {
func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var",
prim_fn_var);
func_with_metadata = WithAttr(func_with_metadata, "prim_funcs",
prim_fns);
func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget,
target);
+ // Store generated Schedules of operator
+ if (sch.defined() && sch->keep_schedule_record) {
+ func_with_metadata = WithAttr(func_with_metadata, "schedule", sch);
+ }
this->process_fn_(func_with_metadata);
}
@@ -926,7 +935,7 @@ class LowerTensorExprMutator : public
DeviceAwareExprMutator {
CachedFunc cfunc = compiler_->Lower(key);
ICHECK(cfunc.defined());
return MakeLoweredCall(primitive_func, cfunc->prim_fn_var,
std::move(new_args),
- call_node->span, target, cfunc->funcs->functions);
+ call_node->span, target, cfunc->funcs->functions,
cfunc->schedule);
}
}
diff --git a/src/te/schedule/schedule_dataflow_rewrite.cc
b/src/te/schedule/schedule_dataflow_rewrite.cc
index c1741e9e4e..c38c5a5c80 100644
--- a/src/te/schedule/schedule_dataflow_rewrite.cc
+++ b/src/te/schedule/schedule_dataflow_rewrite.cc
@@ -174,10 +174,12 @@ Tensor Schedule::cache_read(const Tensor& tensor, const
std::string& scope,
Array<Stage>& stages = (*this)->stages;
Stage op_stage = operator[](tensor->op);
size_t pos = FindNodeRef(stages.GetArrayNode(), op_stage);
- Stage cache_stage = Stage(cache->op);
- cache_stage.set_scope(scope);
+ Stage cache_stage = Stage(cache->op, this->operator->());
ICHECK_LT(pos, stages.size());
stages.insert(stages.begin() + pos + 1, cache_stage);
+ // To record a correct copy in schedule_record, apply the "set_scope"
+ // primitive only after the stage has been added to the schedule.
+ cache_stage.set_scope(scope);
(*this)->stage_map.Set(cache->op, cache_stage);
// Update group
cache_stage->group = op_stage->group;
@@ -266,10 +268,12 @@ Array<Tensor> ReplaceOriginalOp(Schedule sch, Stage
orig_stage, const std::strin
// create schedule for new cached stage.
Array<Stage>& stages = sch->stages;
size_t pos = FindNodeRef(stages.GetArrayNode(), orig_stage);
- Stage cache_stage = Stage(cache_op);
- cache_stage.set_scope(scope);
+ Stage cache_stage = Stage(cache_op, sch.operator->());
ICHECK_LT(pos, stages.size());
stages.insert(stages.begin() + pos, cache_stage);
+ // To record a correct copy in schedule_record, apply the "set_scope"
+ // primitive only after the stage has been added to the schedule.
+ cache_stage.set_scope(scope);
sch->stage_map.Set(cache_op, cache_stage);
// Update group
cache_stage->group = orig_stage->group;
@@ -892,7 +896,7 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor, const
IterVar& axis, int f
Operation factor_op(n);
Array<Stage>& stages = (*this)->stages;
size_t stage_pos = FindNodeRef(stages.GetArrayNode(), reduce_stage);
- Stage factor_stage = Stage(factor_op);
+ Stage factor_stage = Stage(factor_op, this->operator->());
factor_stage->relations = rels;
ICHECK_LT(stage_pos, stages.size());
stages.insert(stages.begin() + stage_pos, factor_stage);
diff --git a/src/te/schedule/schedule_lang.cc b/src/te/schedule/schedule_lang.cc
index e8f4f65eb6..56fe0cfc65 100644
--- a/src/te/schedule/schedule_lang.cc
+++ b/src/te/schedule/schedule_lang.cc
@@ -21,6 +21,7 @@
* \file schedule_lang.cc
*/
#include <dmlc/thread_local.h>
+#include <tvm/ir/transform.h>
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule.h>
@@ -91,7 +92,7 @@ void SplitHelper(StageNode* self, IterVar parent, PrimExpr
factor, PrimExpr npar
leaf_vars.insert(leaf_vars.begin() + pos, outer);
}
-Stage::Stage(Operation op) {
+Stage::Stage(Operation op, const ScheduleNode* sch) {
auto n = make_object<StageNode>();
n->op = op;
n->origin_op = op;
@@ -106,6 +107,7 @@ Stage::Stage(Operation op) {
} else {
n->leaf_iter_vars = clean;
}
+ n->attach_sch = sch;
data_ = std::move(n);
}
@@ -124,11 +126,13 @@ Stage Stage::GetAttachSpec() const {
}
Stage& Stage::set_scope(std::string scope) { // NOLINT(*)
+ With<ScheduleContext> ctx(operator->()->attach_sch, __func__);
(*this)->scope = scope;
return *this;
}
Stage& Stage::compute_at(Stage parent, IterVar scope) { // NOLINT(*)
+ With<ScheduleContext> ctx(operator->()->attach_sch, __func__);
ICHECK_NE((*this)->attach_type, kScanUpdate) << "Cannot specify compute_at
for scan updates";
// Group constraint checking.
Stage group = (*this)->group;
@@ -156,18 +160,21 @@ Stage& Stage::compute_at(Stage parent, IterVar scope) {
// NOLINT(*)
}
Stage& Stage::compute_inline() { // NOLINT(*)
+ With<ScheduleContext> ctx(operator->()->attach_sch, __func__);
ICHECK_NE((*this)->attach_type, kScanUpdate) << "Cannot specify compute_at
for scan updates";
(*this)->attach_type = kInline;
return *this;
}
Stage& Stage::compute_root() { // NOLINT(*)
+ With<ScheduleContext> ctx(operator->()->attach_sch, __func__);
ICHECK_NE((*this)->attach_type, kScanUpdate) << "Cannot specify compute_at
for scan updates";
(*this)->attach_type = kGroupRoot;
return *this;
}
Stage& Stage::bind(IterVar ivar, IterVar thread_ivar) { // NOLINT(*)
+ With<ScheduleContext> ctx(operator->()->attach_sch, __func__);
StageNode* self = operator->();
ICHECK(ivar->iter_type == kDataPar || ivar->iter_type == kCommReduce)
<< "Cannot bind " << IterVarType2String(ivar->iter_type) << " to thread";
@@ -194,6 +201,7 @@ Stage& Stage::bind(IterVar ivar, IterVar thread_ivar) { //
NOLINT(*)
}
Stage& Stage::env_threads(Array<IterVar> threads) {
+ With<ScheduleContext> ctx(operator->()->attach_sch, __func__);
StageNode* self = operator->();
ICHECK(self->op.defined() && self->op.as<ScanOpNode>())
<< "env_threads is only valid for composite ops such as ScanOp";
@@ -211,6 +219,7 @@ Stage& Stage::env_threads(Array<IterVar> threads) {
}
Stage& Stage::set_store_predicate(PrimExpr predicate) {
+ With<ScheduleContext> ctx(operator->()->attach_sch, __func__);
StageNode* self = operator->();
self->store_predicate = predicate;
return *this;
@@ -218,17 +227,20 @@ Stage& Stage::set_store_predicate(PrimExpr predicate) {
Stage& Stage::split(IterVar parent, PrimExpr factor, IterVar* p_outer,
IterVar* p_inner) { // NOLINT(*)
+ With<ScheduleContext> ctx(operator->()->attach_sch, __func__);
SplitHelper(operator->(), parent, factor, PrimExpr(), p_outer, p_inner);
return *this;
}
Stage& Stage::split_by_nparts(IterVar parent, PrimExpr nparts, IterVar*
p_outer,
IterVar* p_inner) { // NOLINT(*)
+ With<ScheduleContext> ctx(operator->()->attach_sch, __func__);
SplitHelper(operator->(), parent, PrimExpr(), nparts, p_outer, p_inner);
return *this;
}
Stage& Stage::fuse(IterVar outer, IterVar inner, IterVar* p_target) { //
NOLINT(*)
+ With<ScheduleContext> ctx(operator->()->attach_sch, __func__);
StageNode* self = operator->();
ICHECK(outer->iter_type == kDataPar || outer->iter_type == kCommReduce ||
outer->iter_type == kOrdered)
@@ -264,6 +276,7 @@ Stage& Stage::fuse(IterVar outer, IterVar inner, IterVar*
p_target) { // NOLINT
}
Stage& Stage::fuse(const Array<IterVar>& axes, IterVar* p_target) { //
NOLINT(*)
+ With<ScheduleContext> ctx(operator->()->attach_sch, __func__);
if (axes.size() != 0) {
IterVar fused = axes[0];
for (size_t i = 1; i < axes.size(); ++i) {
@@ -287,6 +300,7 @@ Stage& Stage::fuse(const Array<IterVar>& axes, IterVar*
p_target) { // NOLINT(*
}
Stage& Stage::reorder(const Array<IterVar>& order) { // NOLINT(*)
+ With<ScheduleContext> ctx(operator->()->attach_sch, __func__);
std::unordered_set<IterVar> seen_var;
StageNode* self = operator->();
for (IterVar iv : order) {
@@ -347,6 +361,7 @@ inline void SetAttrIterType(StageNode* self, IterVar var,
IterVarType iter_type)
}
Stage& Stage::vectorize(IterVar var) { // NOLINT(*)
+ With<ScheduleContext> ctx(operator->()->attach_sch, __func__);
ICHECK(var->iter_type == kDataPar || var->iter_type == kOpaque ||
var->iter_type == kUnrolled ||
var->iter_type == kVectorized || var->iter_type == kTensorized ||
var->iter_type == kParallelized)
@@ -356,6 +371,7 @@ Stage& Stage::vectorize(IterVar var) { // NOLINT(*)
}
Stage& Stage::tensorize(IterVar var, TensorIntrin f) { // NOLINT(*)
+ With<ScheduleContext> ctx(operator->()->attach_sch, __func__);
UpdateIterVarAttr(operator->(), var, [f](IterVarAttrNode* n) {
n->iter_type = kTensorized;
n->tensor_intrin = f;
@@ -364,11 +380,13 @@ Stage& Stage::tensorize(IterVar var, TensorIntrin f) {
// NOLINT(*)
}
Stage& Stage::unroll(IterVar var) { // NOLINT(*)
+ With<ScheduleContext> ctx(operator->()->attach_sch, __func__);
SetAttrIterType(operator->(), var, kUnrolled);
return *this;
}
Stage& Stage::parallel(IterVar var) { // NOLINT(*)
+ With<ScheduleContext> ctx(operator->()->attach_sch, __func__);
SetAttrIterType(operator->(), var, kParallelized);
return *this;
}
@@ -380,6 +398,7 @@ Stage& Stage::pragma(IterVar var, const std::string&
pragma_type,
} else if (pragma_type == "vectorize") {
this->vectorize(var);
} else {
+ With<ScheduleContext> ctx(operator->()->attach_sch, __func__);
UpdateIterVarAttr(operator->(), var, [pragma_type,
pragma_value](IterVarAttrNode* n) {
n->pragma_keys.push_back(tir::StringImm(pragma_type));
n->pragma_values.push_back(pragma_value);
@@ -389,6 +408,7 @@ Stage& Stage::pragma(IterVar var, const std::string&
pragma_type,
}
Stage& Stage::prefetch(const Tensor& tensor, IterVar var, PrimExpr offset) {
+ With<ScheduleContext> ctx(operator->()->attach_sch, __func__);
StageNode* self = operator->();
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
@@ -407,6 +427,7 @@ Stage& Stage::prefetch(const Tensor& tensor, IterVar var,
PrimExpr offset) {
}
Stage& Stage::storage_align(IterVar axis, int factor, int offset) {
+ With<ScheduleContext> ctx(operator->()->attach_sch, __func__);
StageNode* self = operator->();
UpdateIterVarAttr(
self, axis,
@@ -419,6 +440,7 @@ Stage& Stage::storage_align(IterVar axis, int factor, int
offset) {
}
Stage& Stage::double_buffer() {
+ With<ScheduleContext> ctx(operator->()->attach_sch, __func__);
StageNode* self = operator->();
ICHECK(!self->is_output) << "Cannot apply double buffer on output";
self->double_buffer = true;
@@ -426,6 +448,7 @@ Stage& Stage::double_buffer() {
}
Stage& Stage::rolling_buffer() {
+ With<ScheduleContext> ctx(operator->()->attach_sch, __func__);
StageNode* self = operator->();
ICHECK(!self->is_output) << "Cannot apply rolling buffer on output";
self->rolling_buffer = true;
@@ -434,6 +457,7 @@ Stage& Stage::rolling_buffer() {
Stage& Stage::transform_layout(const Array<Var>& initial_indices,
const Array<PrimExpr>& final_indices,
Array<IterVar>* out_iter_vars) {
+ With<ScheduleContext> ctx(operator->()->attach_sch, __func__);
StageNode* self = operator->();
IndexMap map(initial_indices, final_indices);
self->layout_transforms.push_back(map);
@@ -501,6 +525,7 @@ Stage& Stage::transform_layout(const Array<Var>&
initial_indices,
}
Stage& Stage::set_axis_separators(const Array<IntImm>& axis_separators) {
+ With<ScheduleContext> ctx(operator->()->attach_sch, __func__);
StageNode* self = operator->();
self->axis_separators = axis_separators;
return *this;
@@ -630,6 +655,7 @@ Stage Schedule::create_group(const Array<Tensor>& outputs,
const Array<Tensor>&
}
// Create the new group stage.
Stage gstage(make_object<StageNode>());
+ gstage->attach_sch = this->operator->();
gstage->group = parent_group;
if (parent_group.defined()) {
++parent_group->num_child_stages;
@@ -718,6 +744,8 @@ bool ScheduleNode::Contain(const Operation& op) const {
return stage_map.find(op) != stage_map.end();
}
+TVM_REGISTER_PASS_CONFIG_OPTION("te.keep_schedule_record", Bool);
+
Schedule::Schedule(Array<Operation> ops) {
auto n = make_object<ScheduleNode>();
data_ = n;
@@ -730,7 +758,7 @@ Schedule::Schedule(Array<Operation> ops) {
output_set.insert(x);
}
for (Operation op : post_order) {
- Stage stage(op);
+ Stage stage(op, this->operator->());
stage->is_output = output_set.count(op) != 0;
n->stages.push_back(stage);
n->stage_map.Set(op, stage);
@@ -754,6 +782,25 @@ Schedule::Schedule(Array<Operation> ops) {
}
}
}
+ transform::PassContext pass_ctx = transform::PassContext::Current();
+ n->keep_schedule_record =
pass_ctx->GetConfig<Bool>("te.keep_schedule_record", Bool(false));
+ if (n->keep_schedule_record.value()) {
+ // push plain schedule as the very first one
+ n->schedule_record.push_back(copy());
+ n->primitive_record.push_back("vanilla");
+ }
+}
+
+ScheduleContext::ScheduleContext(const ScheduleNode* sch_node, String
current_primitive_name)
+ : sch_(GetRef<Schedule>(sch_node)),
current_primitive_name_(current_primitive_name) {}
+
+void ScheduleContext::EnterWithScope() {}
+
+void ScheduleContext::ExitWithScope() {
+ if (sch_.defined() && sch_->keep_schedule_record.value()) {
+ sch_->schedule_record.push_back(sch_.copy());
+ sch_->primitive_record.push_back(current_primitive_name_);
+ }
}
Split::Split(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor,
PrimExpr nparts) {
diff --git a/tests/python/contrib/test_tedd.py
b/tests/python/contrib/test_tedd.py
index 373fb14d36..c1af9f6825 100644
--- a/tests/python/contrib/test_tedd.py
+++ b/tests/python/contrib/test_tedd.py
@@ -16,8 +16,12 @@
# under the License.
import re
+import tvm
from tvm import te
from tvm import topi
+from tvm import relay
+from tvm.relay import testing
+from tvm.relay.backend import Runtime, Executor
def findany(pattern, str):
@@ -79,8 +83,8 @@ def test_itervar_relationship_graph():
findany(r"subgraph cluster_Stage_0", str)
findany(r"subgraph cluster_Stage_1", str)
# Check itervars and their types
- findany(r"\(kDataPar\)\<br/\>range\(min=0, ext=n\)", str)
- findany(r"\(kCommReduce\)\<br/\>range\(min=0, ext=m\)", str)
+ findany(r"\(kDataPar\)\<br/\>T.Range\(0, n\)", str)
+ findany(r"\(kCommReduce\)\<br/\>T.Range\(0, m\)", str)
# Check the split node
findany(r"Split_Relation_1_0 +.+\>Split", str)
# Check all edges to/from the split node
@@ -144,7 +148,57 @@ def test_schedule_tree():
verify()
[email protected]_llvm
+def test_tedd_with_schedule_record():
+ """Test to build a nn model and check if all schedules could be
generated"""
+
+ def check_schedule(executor):
+ from tvm.contrib import tedd
+
+ error = {}
+ for func_name, func_meta in executor.function_metadata.items():
+ # check converted op only
+ if "main" not in func_name:
+ primfunc = list(func_meta.relay_primfuncs.values())[0]
+ schs = primfunc.attrs["schedule"].schedule_record
+ for index in range(len(schs)):
+ try:
+ sch = schs[index].normalize()
+ tedd.viz_dataflow_graph(sch, False, "", True)
+ tedd.viz_itervar_relationship_graph(sch, False, "",
True)
+ tedd.viz_schedule_tree(sch, False, "", True)
+ except:
+ if func_name not in error:
+ error[func_name] = []
+ error[func_name].append(index)
+
+ assert error == {}, str(error)
+
+ if checkdependency():
+ relay_mod, params = testing.mobilenet.get_workload(batch_size=1,
dtype="float32")
+ target_llvm = tvm.target.Target("llvm")
+ config = {"te.keep_schedule_record": True}
+
+ with tvm.transform.PassContext(opt_level=3, config=config):
+ aot_executor_factory = relay.build(
+ relay_mod,
+ target_llvm,
+ runtime=Runtime("cpp"),
+ executor=Executor("aot"),
+ params=params,
+ )
+ graph_executor_factory = relay.build(
+ relay_mod,
+ target_llvm,
+ params=params,
+ )
+
+ check_schedule(aot_executor_factory)
+ check_schedule(graph_executor_factory)
+
+
if __name__ == "__main__":
test_dfg()
test_itervar_relationship_graph()
test_schedule_tree()
+ test_tedd_with_schedule_record()
diff --git a/tests/python/relay/test_build_module.py
b/tests/python/relay/test_build_module.py
index 5cfc27330a..b1146743ee 100644
--- a/tests/python/relay/test_build_module.py
+++ b/tests/python/relay/test_build_module.py
@@ -21,6 +21,7 @@ import tvm
import tvm.testing
from tvm import relay
from tvm.target.target import Target
+from tvm.relay import testing
from tvm.relay.backend import Runtime, Executor, graph_executor_codegen
@@ -62,5 +63,41 @@ def test_build_relay_graph_():
build_graph(add((1, 8), "float32"), tvm.target.Target("llvm"))
[email protected]_llvm
+def test_schedule_record():
+ """Test to build a nn model and get schedule_record from build_module"""
+
+ def check_schedule(executor):
+ for func_name, func_meta in executor.function_metadata.items():
+ # check converted op only
+ if "main" not in func_name:
+ primfunc = list(func_meta.relay_primfuncs.values())[0]
+ # make sure schedule is well-stored in function metadata
+ assert "schedule" in primfunc.attrs
+ sch = primfunc.attrs["schedule"]
+ assert len(sch.schedule_record) == len(sch.primitive_record)
+
+ relay_mod, params = testing.mobilenet.get_workload(batch_size=1,
dtype="float32")
+ target_llvm = tvm.target.Target("llvm")
+ config = {"te.keep_schedule_record": True}
+
+ with tvm.transform.PassContext(opt_level=3, config=config):
+ aot_executor_factory = relay.build(
+ relay_mod,
+ target_llvm,
+ runtime=Runtime("cpp"),
+ executor=Executor("aot"),
+ params=params,
+ )
+ graph_executor_factory = relay.build(
+ relay_mod,
+ target_llvm,
+ params=params,
+ )
+
+ check_schedule(aot_executor_factory)
+ check_schedule(graph_executor_factory)
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/unittest/test_te_schedule_ops.py
b/tests/python/unittest/test_te_schedule_ops.py
index f85cdc6196..1ff0297539 100644
--- a/tests/python/unittest/test_te_schedule_ops.py
+++ b/tests/python/unittest/test_te_schedule_ops.py
@@ -614,6 +614,57 @@ def test_local_stage_predicate2():
assert any(collect_visit(lowered_body, visit_stmt))
+def test_schedule_record_gemm():
+ with tvm.transform.PassContext(config={"te.keep_schedule_record": True}):
+ M, K, N = 1024, 1024, 1024
+ k = te.reduce_axis((0, K), "k")
+ A = te.placeholder((M, K), name="A")
+ B = te.placeholder((K, N), name="B")
+ C = te.compute((M, N), lambda m, n: te.sum(A[m, k] * B[k, n], axis=k),
name="C")
+ s = te.create_schedule(C.op)
+ # currently there are no other applied primitives
+ # size of schedule record is expected to be 1 (vanilla schedule)
+ assert len(s.schedule_record) == 1
+ # apply sequential optimization primitives
+ block_size, factor = 32, 8
+ # tile -> split + split + reorder
+ mo, no, mi, ni = s[C].tile(C.op.axis[0], C.op.axis[1], block_size,
block_size)
+ ko, ki = s[C].split(k, factor=factor)
+ s[C].reorder(mo, ko, no, mi, ki, ni)
+ s[C].vectorize(ni)
+ s[C].parallel(mo)
+ assert len(s.schedule_record) == 8
+ # compare primitive names
+ expected_names = [
+ "vanilla",
+ "split",
+ "split",
+ "reorder",
+ "split",
+ "reorder",
+ "vectorize",
+ "parallel",
+ ]
+ for i in range(len(s.schedule_record)):
+ assert s.primitive_record[i] == expected_names[i]
+
+
+def test_schedule_record_misc():
+ s = te.create_schedule([])
+ # size of schedule record is expected to be 0 (no storing behavior)
+ assert len(s.schedule_record) == 0
+
+ with tvm.transform.PassContext(config={"te.keep_schedule_record": True}):
+ s = te.create_schedule([])
+ # size of schedule record is expected to be 1 (vanilla schedule)
+ assert len(s.schedule_record) == 1
+
+ stg = te.compute((), lambda *args: 0, name="empty_op")
+ s = te.create_schedule(stg.op)
+ # size of schedule record is expected to be 1 (vanilla schedule)
+ assert len(s.schedule_record) == 1
+
+
if __name__ == "__main__":
test_loop_dep_reduce()
test_loop_dep_reduce_cache_write()
@@ -640,3 +691,5 @@ if __name__ == "__main__":
test_schedule_compute_inline()
test_local_stage_predicate()
test_local_stage_predicate2()
+ test_schedule_record_gemm()
+ test_schedule_record_misc()