This is an automated email from the ASF dual-hosted git repository.
lmzheng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new f904d4f [AutoScheduler] Fix policy for zero-rank output (#7180)
f904d4f is described below
commit f904d4fe95b16044a8d46edeea0b7b7792e0ef3c
Author: Lianmin Zheng <[email protected]>
AuthorDate: Wed Dec 30 04:18:41 2020 -0800
[AutoScheduler] Fix policy for zero-rank output (#7180)
---
src/auto_scheduler/search_policy/sketch_policy.cc | 15 ++--
src/auto_scheduler/search_policy/sketch_policy.h | 7 ++
src/auto_scheduler/search_policy/utils.h | 7 +-
src/auto_scheduler/transform_step.cc | 35 ++++++---
.../python/unittest/test_auto_scheduler_common.py | 17 +++++
.../unittest/test_auto_scheduler_search_policy.py | 83 ++++++++++++++++------
.../test_auto_scheduler_sketch_generation.py | 16 +++++
7 files changed, 136 insertions(+), 44 deletions(-)
diff --git a/src/auto_scheduler/search_policy/sketch_policy.cc
b/src/auto_scheduler/search_policy/sketch_policy.cc
index e267837..1e20b0f 100644
--- a/src/auto_scheduler/search_policy/sketch_policy.cc
+++ b/src/auto_scheduler/search_policy/sketch_policy.cc
@@ -78,6 +78,8 @@ SketchPolicy::SketchPolicy(SearchTask task, CostModel
program_cost_model,
node->rand_gen = std::mt19937(seed);
node->params = std::move(params);
node->verbose = verbose;
+ node->sample_init_min_pop_ =
+ GetIntParam(node->params,
SketchParamKey::SampleInitPopulation::min_population);
if (init_search_callbacks) {
PrintTitle("Call init-search callbacks", verbose);
@@ -382,8 +384,6 @@ Array<State> SketchPolicyNode::GenerateSketches() {
Array<State> SketchPolicyNode::SampleInitPopulation(const Array<State>&
sketches) {
// Use this population as the parallel degree to do sampling
int population = GetIntParam(params,
SketchParamKey::EvolutionarySearch::population);
- // At least we should sample this number of valid programs
- int min_population = GetIntParam(params,
SketchParamKey::SampleInitPopulation::min_population);
auto tic_begin = std::chrono::high_resolution_clock::now();
@@ -397,9 +397,8 @@ Array<State> SketchPolicyNode::SampleInitPopulation(const
Array<State>& sketches
std::unordered_set<std::string> explored_state_strs;
size_t iter = 1;
- size_t target_size = min_population;
size_t unchange_cnt = 0;
- while (out_states.size() < target_size) {
+ while (static_cast<int>(out_states.size()) < sample_init_min_pop_) {
std::vector<State> temp_states(population);
// Sample a batch of states randomly
@@ -458,7 +457,7 @@ Array<State> SketchPolicyNode::SampleInitPopulation(const
Array<State>& sketches
std::chrono::high_resolution_clock::now() -
tic_begin)
.count();
StdCout(verbose) << "Sample Iter: " << iter << std::fixed <<
std::setprecision(4)
- << "\t#Pop: " << out_states.size() << "\t#Target: " <<
target_size
+ << "\t#Pop: " << out_states.size() << "\t#Target: " <<
sample_init_min_pop_
<< "\tfail_ct: " << fail_ct << "\tTime elapsed: " <<
std::fixed
<< std::setprecision(2) << duration << std::endl;
}
@@ -466,9 +465,9 @@ Array<State> SketchPolicyNode::SampleInitPopulation(const
Array<State>& sketches
if (unchange_cnt == 5) {
// Reduce the target size to avoid too-long time in this phase if no
valid state was found
// in the past iterations
- if (target_size > 1) {
- target_size /= 2;
- StdCout(verbose) << "#Target has been reduced to " << target_size
+ if (sample_init_min_pop_ > 1) {
+ sample_init_min_pop_ /= 2;
+ StdCout(verbose) << "#Target has been reduced to " <<
sample_init_min_pop_
<< " due to too many failures or duplications" <<
std::endl;
}
unchange_cnt = 0;
diff --git a/src/auto_scheduler/search_policy/sketch_policy.h
b/src/auto_scheduler/search_policy/sketch_policy.h
index 3d135d1..4886349 100644
--- a/src/auto_scheduler/search_policy/sketch_policy.h
+++ b/src/auto_scheduler/search_policy/sketch_policy.h
@@ -87,6 +87,8 @@ struct SketchParamKey {
static constexpr const char* disable_change_compute_location =
"disable_change_compute_location";
};
+class SketchPolicy;
+
/*!
* \brief The search policy that searches in a hierarchical search space
defined by sketches.
* The policy randomly samples programs from the space defined by sketches
@@ -166,6 +168,11 @@ class SketchPolicyNode : public SearchPolicyNode {
/*! \brief The cached sketches */
Array<State> sketch_cache_;
+
+ /*! \brief The minimum output population of SampleInitPopulation */
+ int sample_init_min_pop_;
+
+ friend class SketchPolicy;
};
/*!
diff --git a/src/auto_scheduler/search_policy/utils.h
b/src/auto_scheduler/search_policy/utils.h
index d59a6ca..eb2cd69 100644
--- a/src/auto_scheduler/search_policy/utils.h
+++ b/src/auto_scheduler/search_policy/utils.h
@@ -609,12 +609,11 @@ inline State FuseAllOuterSpaceIterators(const State&
state, int stage_id, Iterat
to_fuse.push_back(it);
}
- ICHECK(!to_fuse.empty());
State tmp_s = state;
- if (to_fuse.size() > 1) {
- *fused_iter = tmp_s.fuse(stage_id, to_fuse);
- } else {
+ if (to_fuse.size() == 1) {
*fused_iter = to_fuse[0];
+ } else {
+ *fused_iter = tmp_s.fuse(stage_id, to_fuse);
}
return tmp_s;
}
diff --git a/src/auto_scheduler/transform_step.cc
b/src/auto_scheduler/transform_step.cc
index 5560907..5ba3eee 100755
--- a/src/auto_scheduler/transform_step.cc
+++ b/src/auto_scheduler/transform_step.cc
@@ -538,15 +538,25 @@ Iterator FuseStepNode::ApplyToState(State* state) const {
Iterator new_it =
Iterator(new_name, range, new_iter_kind, IteratorAnnotation::kNone,
&orig_iters);
Array<Iterator> new_iters;
- new_iters.insert(new_iters.end(), stage->iters.begin(), stage->iters.begin()
+ fused_ids.front());
- new_iters.push_back(new_it);
- new_iters.insert(new_iters.end(), stage->iters.begin() + fused_ids.back() +
1,
- stage->iters.end());
+
+ if (fused_ids.empty()) {
+ new_iters.push_back(new_it);
+ } else {
+ new_iters.insert(new_iters.end(), stage->iters.begin(),
+ stage->iters.begin() + fused_ids.front());
+ new_iters.push_back(new_it);
+ new_iters.insert(new_iters.end(), stage->iters.begin() + fused_ids.back()
+ 1,
+ stage->iters.end());
+ }
StateNode* pstate = state->CopyOnWrite();
pstate->stages.Set(stage_id,
Stage(stage->op, stage->op_type, new_iters,
stage->compute_at, stage->attrs));
+ if (fused_ids.empty()) {
+ return new_it;
+ }
+
// Two vectors are used to represent the iterator relation before and after
fuse
// The original iterators in AttachMap will be updated with the new iterators
std::vector<IterKey> from_iters;
@@ -583,9 +593,13 @@ IterVar FuseStepNode::ApplyToSchedule(Array<te::Stage>*
stages,
stage.fuse(to_fuse, &fused_axis);
Array<IterVar> new_axes;
- new_axes.insert(new_axes.end(), axes.begin(), axes.begin() +
fused_ids.front());
- new_axes.push_back(fused_axis);
- new_axes.insert(new_axes.end(), axes.begin() + fused_ids.back() + 1,
axes.end());
+ if (fused_ids.empty()) {
+ new_axes.push_back(fused_axis);
+ } else {
+ new_axes.insert(new_axes.end(), axes.begin(), axes.begin() +
fused_ids.front());
+ new_axes.push_back(fused_axis);
+ new_axes.insert(new_axes.end(), axes.begin() + fused_ids.back() + 1,
axes.end());
+ }
stage_to_axes->Set(stage, std::move(new_axes));
stages->Set(stage_id, std::move(stage));
@@ -683,9 +697,12 @@ void PragmaStepNode::ApplyToSchedule(Array<te::Stage>*
stages,
}
ICHECK_LT(pos, pragma_type.size()) << "max step value not found.";
int value = atoi(pragma_type.c_str() + pos + 1);
- stage.pragma(axes[iter_id], "auto_unroll_max_step", value);
- stage.pragma(axes[iter_id], "unroll_explicit", true);
+ if (iter_id < static_cast<int>(axes.size())) {
+ stage.pragma(axes[iter_id], "auto_unroll_max_step", value);
+ stage.pragma(axes[iter_id], "unroll_explicit", true);
+ }
} else {
+ ICHECK_LT(iter_id, axes.size());
stage.pragma(axes[iter_id], pragma_type);
}
stages->Set(stage_id, std::move(stage));
diff --git a/tests/python/unittest/test_auto_scheduler_common.py
b/tests/python/unittest/test_auto_scheduler_common.py
index a037b68..2f94231 100644
--- a/tests/python/unittest/test_auto_scheduler_common.py
+++ b/tests/python/unittest/test_auto_scheduler_common.py
@@ -146,6 +146,23 @@ def invalid_compute_definition():
@auto_scheduler.register_workload
+def zero_rank_reduce_auto_scheduler_test(N):
+ A = tvm.te.placeholder((N,), name="A")
+ k = tvm.te.reduce_axis((0, N), name="k")
+ B = tvm.te.compute((), lambda: tvm.te.sum(A[k], k), name="B")
+
+ return [A, B]
+
+
+@auto_scheduler.register_workload
+def zero_rank_compute_auto_scheduler_test(N):
+ A = tvm.te.placeholder((N,), name="A")
+ B = tvm.te.compute((), lambda: A[0], name="B")
+
+ return [A, B]
+
+
+@auto_scheduler.register_workload
def conv2d_winograd_nhwc_auto_scheduler_test(
N, H, W, CI, CO, kernel_size=3, stride=1, padding=0, dilation=1
):
diff --git a/tests/python/unittest/test_auto_scheduler_search_policy.py
b/tests/python/unittest/test_auto_scheduler_search_policy.py
index 73ce0a1..c96dc63 100644
--- a/tests/python/unittest/test_auto_scheduler_search_policy.py
+++ b/tests/python/unittest/test_auto_scheduler_search_policy.py
@@ -25,8 +25,13 @@ import tempfile
import tvm
import tvm.testing
from tvm import auto_scheduler
+from tvm.auto_scheduler.utils import get_const_tuple
-from test_auto_scheduler_common import matmul_auto_scheduler_test
+from test_auto_scheduler_common import (
+ matmul_auto_scheduler_test,
+ zero_rank_compute_auto_scheduler_test,
+ zero_rank_reduce_auto_scheduler_test,
+)
import multiprocessing
@@ -41,21 +46,21 @@ class
CustomMeasureCallback(auto_scheduler.measure.PythonBasedMeasureCallback):
def search_common(
- workload=matmul_auto_scheduler_test,
+ task=None,
target="llvm",
search_policy="sketch",
- seed=0,
runner="local",
num_measure_trials=100,
cost_model=auto_scheduler.RandomModel(),
init_search_callbacks=None,
):
- print("Test search policy '%s' for '%s'" % (search_policy, target))
+ if task is None:
+ task = auto_scheduler.SearchTask(
+ func=matmul_auto_scheduler_test, args=(64, 64, 64), target=target
+ )
+ target = task.target
- random.seed(seed)
- N = 128
- target = tvm.target.Target(target)
- task = auto_scheduler.SearchTask(func=workload, args=(N, N, N),
target=target)
+ print("Test search policy '%s' for '%s'" % (search_policy, target))
with tempfile.NamedTemporaryFile() as fp:
log_file = fp.name
@@ -72,6 +77,7 @@ def search_common(
else:
raise ValueError("Invalid policy: " + search_policy)
+ # Tune
tuning_options = auto_scheduler.TuningOptions(
num_measure_trials=num_measure_trials,
num_measures_per_round=2,
@@ -80,33 +86,47 @@ def search_common(
measure_callbacks=[auto_scheduler.RecordToFile(log_file),
CustomMeasureCallback()],
)
task.tune(tuning_options=tuning_options, search_policy=search_policy)
+
+ # Compile with the best schedule
sch, args = task.apply_best(log_file)
+ mod = tvm.build(sch, args, target)
- try:
- mod = tvm.build(sch, args, target)
+ # Compile with naive schedule for correctness check
+ sch, args =
task.compute_dag.apply_steps_from_state(task.compute_dag.init_state)
+ mod_ref = tvm.build(sch, args, "llvm")
- ctx = tvm.context(str(target), 0)
- dtype = task.compute_dag.tensors[0].dtype
- a = tvm.nd.array(np.random.uniform(size=(N, N)).astype(dtype), ctx)
- b = tvm.nd.array(np.random.uniform(size=(N, N)).astype(dtype), ctx)
- c = tvm.nd.array(np.zeros((N, N), dtype=dtype), ctx)
- mod(a, b, c)
- tvm.testing.assert_allclose(c.asnumpy(), np.dot(a.asnumpy(),
b.asnumpy()), rtol=1e-5)
- except Exception:
- raise Exception("Error encountered with seed: %d" % (seed))
+ ctx = tvm.context(str(target), 0)
+ np_arrays =
[np.random.uniform(size=get_const_tuple(x.shape)).astype(x.dtype) for x in args]
+
+ tvm_arrays = [tvm.nd.array(x, ctx) for x in np_arrays]
+ mod(*tvm_arrays)
+ actual = [x.asnumpy() for x in tvm_arrays]
+
+ tvm_arrays = [tvm.nd.array(x) for x in np_arrays]
+ mod_ref(*tvm_arrays)
+ expected = [x.asnumpy() for x in tvm_arrays]
+
+ for x, y in zip(actual, expected):
+ tvm.testing.assert_allclose(x, y, rtol=1e-5)
@tvm.testing.requires_llvm
-def test_workload_registry_search_basic():
+def test_workload_registry_empty_policy():
search_common(search_policy="empty", num_measure_trials=2)
+ N = 64
+ target = "llvm"
search_common(
- workload="matmul_auto_scheduler_test",
+ task=auto_scheduler.SearchTask(
+ func="matmul_auto_scheduler_test", args=(N, N, N), target=target
+ ),
num_measure_trials=2,
search_policy="empty",
)
search_common(
- workload="matmul_auto_scheduler_test_rename_1",
+ task=auto_scheduler.SearchTask(
+ func="matmul_auto_scheduler_test_rename_1", args=(N, N, N),
target=target
+ ),
num_measure_trials=2,
search_policy="empty",
)
@@ -147,10 +167,27 @@ def test_sketch_search_policy_cuda_xgbmodel_rpc_runner():
search_common(target="cuda", runner=measure_ctx.runner,
cost_model=auto_scheduler.XGBModel())
[email protected]_llvm
[email protected]_cuda
+def test_sketch_search_policy_zero_rank():
+ measure_ctx = auto_scheduler.LocalRPCMeasureContext()
+ for target in ["llvm", "cuda"]:
+ task = auto_scheduler.SearchTask(
+ func=zero_rank_compute_auto_scheduler_test, args=(10,),
target=target
+ )
+ search_common(task, runner=measure_ctx.runner)
+
+ task = auto_scheduler.SearchTask(
+ func=zero_rank_reduce_auto_scheduler_test, args=(10,),
target=target
+ )
+ search_common(task, runner=measure_ctx.runner)
+
+
if __name__ == "__main__":
- test_workload_registry_search_basic()
+ test_workload_registry_empty_policy()
test_sketch_search_policy_basic()
test_sketch_search_policy_basic_spawn()
test_sketch_search_policy_xgbmodel()
test_sketch_search_policy_cuda_rpc_runner()
test_sketch_search_policy_cuda_xgbmodel_rpc_runner()
+ test_sketch_search_policy_zero_rank()
diff --git a/tests/python/unittest/test_auto_scheduler_sketch_generation.py
b/tests/python/unittest/test_auto_scheduler_sketch_generation.py
index 74d5729..ddff6dd 100644
--- a/tests/python/unittest/test_auto_scheduler_sketch_generation.py
+++ b/tests/python/unittest/test_auto_scheduler_sketch_generation.py
@@ -32,6 +32,7 @@ from test_auto_scheduler_common import (
softmax_nm_auto_scheduler_test,
softmax_abcd_auto_scheduler_test,
conv2d_winograd_nhwc_auto_scheduler_test,
+ zero_rank_reduce_auto_scheduler_test,
)
@@ -252,6 +253,12 @@ def test_cpu_conv2d_winograd_sketch():
assert sketches[1] != sketches[2]
+def test_cpu_zero_rank_sketch():
+ sketches = generate_sketches(zero_rank_reduce_auto_scheduler_test, (128,),
"llvm")
+ """ 2 rfactor sketches + 1 multi-level tiling sketch """
+ assert len(sketches) == 3
+
+
@tvm.testing.requires_cuda
def test_cuda_matmul_sketch():
sketches = generate_sketches(matmul_auto_scheduler_test, (512, 512, 512),
"cuda")
@@ -385,6 +392,13 @@ def test_cuda_conv2d_winograd_sketch():
assert_is_not_tiled(sketches[0].stages[12])
[email protected]_cuda
+def test_cuda_zero_rank_sketch():
+ sketches = generate_sketches(zero_rank_reduce_auto_scheduler_test, (128,),
"cuda")
+ """ 1 cross thread reduction sketch + 1 multi-level tiling sketch """
+ assert len(sketches) == 2
+
+
if __name__ == "__main__":
test_cpu_matmul_sketch()
test_cpu_conv2d_bn_relu_sketch()
@@ -392,9 +406,11 @@ if __name__ == "__main__":
test_cpu_min_sketch()
test_cpu_softmax_sketch()
test_cpu_conv2d_winograd_sketch()
+ test_cpu_zero_rank_sketch()
test_cuda_matmul_sketch()
test_cuda_conv2d_bn_relu_sketch()
test_cuda_max_pool2d_sketch()
test_cuda_min_sketch()
test_cuda_softmax_sketch()
test_cuda_conv2d_winograd_sketch()
+ test_cuda_zero_rank_sketch()