This is an automated email from the ASF dual-hosted git repository.

lmzheng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new f904d4f  [AutoScheduler] Fix policy for zero-rank output (#7180)
f904d4f is described below

commit f904d4fe95b16044a8d46edeea0b7b7792e0ef3c
Author: Lianmin Zheng <[email protected]>
AuthorDate: Wed Dec 30 04:18:41 2020 -0800

    [AutoScheduler] Fix policy for zero-rank output (#7180)
---
 src/auto_scheduler/search_policy/sketch_policy.cc  | 15 ++--
 src/auto_scheduler/search_policy/sketch_policy.h   |  7 ++
 src/auto_scheduler/search_policy/utils.h           |  7 +-
 src/auto_scheduler/transform_step.cc               | 35 ++++++---
 .../python/unittest/test_auto_scheduler_common.py  | 17 +++++
 .../unittest/test_auto_scheduler_search_policy.py  | 83 ++++++++++++++++------
 .../test_auto_scheduler_sketch_generation.py       | 16 +++++
 7 files changed, 136 insertions(+), 44 deletions(-)

diff --git a/src/auto_scheduler/search_policy/sketch_policy.cc b/src/auto_scheduler/search_policy/sketch_policy.cc
index e267837..1e20b0f 100644
--- a/src/auto_scheduler/search_policy/sketch_policy.cc
+++ b/src/auto_scheduler/search_policy/sketch_policy.cc
@@ -78,6 +78,8 @@ SketchPolicy::SketchPolicy(SearchTask task, CostModel program_cost_model,
   node->rand_gen = std::mt19937(seed);
   node->params = std::move(params);
   node->verbose = verbose;
+  node->sample_init_min_pop_ =
+      GetIntParam(node->params, SketchParamKey::SampleInitPopulation::min_population);
 
   if (init_search_callbacks) {
     PrintTitle("Call init-search callbacks", verbose);
@@ -382,8 +384,6 @@ Array<State> SketchPolicyNode::GenerateSketches() {
 Array<State> SketchPolicyNode::SampleInitPopulation(const Array<State>& sketches) {
   // Use this population as the parallel degree to do sampling
   int population = GetIntParam(params, SketchParamKey::EvolutionarySearch::population);
-  // At least we should sample this number of valid programs
-  int min_population = GetIntParam(params, SketchParamKey::SampleInitPopulation::min_population);
 
   auto tic_begin = std::chrono::high_resolution_clock::now();
 
@@ -397,9 +397,8 @@ Array<State> SketchPolicyNode::SampleInitPopulation(const Array<State>& sketches
 
   std::unordered_set<std::string> explored_state_strs;
   size_t iter = 1;
-  size_t target_size = min_population;
   size_t unchange_cnt = 0;
-  while (out_states.size() < target_size) {
+  while (static_cast<int>(out_states.size()) < sample_init_min_pop_) {
     std::vector<State> temp_states(population);
 
     // Sample a batch of states randomly
@@ -458,7 +457,7 @@ Array<State> SketchPolicyNode::SampleInitPopulation(const Array<State>& sketches
                             std::chrono::high_resolution_clock::now() - tic_begin)
                             .count();
       StdCout(verbose) << "Sample Iter: " << iter << std::fixed << std::setprecision(4)
-                       << "\t#Pop: " << out_states.size() << "\t#Target: " << target_size
+                       << "\t#Pop: " << out_states.size() << "\t#Target: " << sample_init_min_pop_
                        << "\tfail_ct: " << fail_ct << "\tTime elapsed: " << std::fixed
                        << std::setprecision(2) << duration << std::endl;
     }
@@ -466,9 +465,9 @@ Array<State> SketchPolicyNode::SampleInitPopulation(const Array<State>& sketches
     if (unchange_cnt == 5) {
       // Reduce the target size to avoid too-long time in this phase if no valid state was found
       // in the past iterations
-      if (target_size > 1) {
-        target_size /= 2;
-        StdCout(verbose) << "#Target has been reduced to " << target_size
+      if (sample_init_min_pop_ > 1) {
+        sample_init_min_pop_ /= 2;
+        StdCout(verbose) << "#Target has been reduced to " << sample_init_min_pop_
                          << " due to too many failures or duplications" << std::endl;
       }
       unchange_cnt = 0;
diff --git a/src/auto_scheduler/search_policy/sketch_policy.h b/src/auto_scheduler/search_policy/sketch_policy.h
index 3d135d1..4886349 100644
--- a/src/auto_scheduler/search_policy/sketch_policy.h
+++ b/src/auto_scheduler/search_policy/sketch_policy.h
@@ -87,6 +87,8 @@ struct SketchParamKey {
   static constexpr const char* disable_change_compute_location = "disable_change_compute_location";
 };
 
+class SketchPolicy;
+
 /*!
  * \brief The search policy that searches in a hierarchical search space defined by sketches.
  * The policy randomly samples programs from the space defined by sketches
@@ -166,6 +168,11 @@ class SketchPolicyNode : public SearchPolicyNode {
 
   /*! \brief The cached sketches */
   Array<State> sketch_cache_;
+
+  /*! \brief The minimul output population of SampleInitPopulation */
+  int sample_init_min_pop_;
+
+  friend class SketchPolicy;
 };
 
 /*!
diff --git a/src/auto_scheduler/search_policy/utils.h b/src/auto_scheduler/search_policy/utils.h
index d59a6ca..eb2cd69 100644
--- a/src/auto_scheduler/search_policy/utils.h
+++ b/src/auto_scheduler/search_policy/utils.h
@@ -609,12 +609,11 @@ inline State FuseAllOuterSpaceIterators(const State& state, int stage_id, Iterat
     to_fuse.push_back(it);
   }
 
-  ICHECK(!to_fuse.empty());
   State tmp_s = state;
-  if (to_fuse.size() > 1) {
-    *fused_iter = tmp_s.fuse(stage_id, to_fuse);
-  } else {
+  if (to_fuse.size() == 1) {
     *fused_iter = to_fuse[0];
+  } else {
+    *fused_iter = tmp_s.fuse(stage_id, to_fuse);
   }
   return tmp_s;
 }
diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc
index 5560907..5ba3eee 100755
--- a/src/auto_scheduler/transform_step.cc
+++ b/src/auto_scheduler/transform_step.cc
@@ -538,15 +538,25 @@ Iterator FuseStepNode::ApplyToState(State* state) const {
   Iterator new_it =
       Iterator(new_name, range, new_iter_kind, IteratorAnnotation::kNone, &orig_iters);
   Array<Iterator> new_iters;
-  new_iters.insert(new_iters.end(), stage->iters.begin(), stage->iters.begin() + fused_ids.front());
-  new_iters.push_back(new_it);
-  new_iters.insert(new_iters.end(), stage->iters.begin() + fused_ids.back() + 1,
-                   stage->iters.end());
+
+  if (fused_ids.empty()) {
+    new_iters.push_back(new_it);
+  } else {
+    new_iters.insert(new_iters.end(), stage->iters.begin(),
+                     stage->iters.begin() + fused_ids.front());
+    new_iters.push_back(new_it);
+    new_iters.insert(new_iters.end(), stage->iters.begin() + fused_ids.back() + 1,
+                     stage->iters.end());
+  }
 
   StateNode* pstate = state->CopyOnWrite();
   pstate->stages.Set(stage_id,
                      Stage(stage->op, stage->op_type, new_iters, stage->compute_at, stage->attrs));
 
+  if (fused_ids.empty()) {
+    return new_it;
+  }
+
   // Two vectors are used to represent the iterator relation before and after fuse
   // The original iterators in AttachMap will be updated with the new iterators
   std::vector<IterKey> from_iters;
@@ -583,9 +593,13 @@ IterVar FuseStepNode::ApplyToSchedule(Array<te::Stage>* stages,
   stage.fuse(to_fuse, &fused_axis);
 
   Array<IterVar> new_axes;
-  new_axes.insert(new_axes.end(), axes.begin(), axes.begin() + fused_ids.front());
-  new_axes.push_back(fused_axis);
-  new_axes.insert(new_axes.end(), axes.begin() + fused_ids.back() + 1, axes.end());
+  if (fused_ids.empty()) {
+    new_axes.push_back(fused_axis);
+  } else {
+    new_axes.insert(new_axes.end(), axes.begin(), axes.begin() + fused_ids.front());
+    new_axes.push_back(fused_axis);
+    new_axes.insert(new_axes.end(), axes.begin() + fused_ids.back() + 1, axes.end());
+  }
 
   stage_to_axes->Set(stage, std::move(new_axes));
   stages->Set(stage_id, std::move(stage));
@@ -683,9 +697,12 @@ void PragmaStepNode::ApplyToSchedule(Array<te::Stage>* stages,
     }
     ICHECK_LT(pos, pragma_type.size()) << "max step value not found.";
     int value = atoi(pragma_type.c_str() + pos + 1);
-    stage.pragma(axes[iter_id], "auto_unroll_max_step", value);
-    stage.pragma(axes[iter_id], "unroll_explicit", true);
+    if (iter_id < static_cast<int>(axes.size())) {
+      stage.pragma(axes[iter_id], "auto_unroll_max_step", value);
+      stage.pragma(axes[iter_id], "unroll_explicit", true);
+    }
   } else {
+    ICHECK_LT(iter_id, axes.size());
     stage.pragma(axes[iter_id], pragma_type);
   }
   stages->Set(stage_id, std::move(stage));
diff --git a/tests/python/unittest/test_auto_scheduler_common.py b/tests/python/unittest/test_auto_scheduler_common.py
index a037b68..2f94231 100644
--- a/tests/python/unittest/test_auto_scheduler_common.py
+++ b/tests/python/unittest/test_auto_scheduler_common.py
@@ -146,6 +146,23 @@ def invalid_compute_definition():
 
 
 @auto_scheduler.register_workload
+def zero_rank_reduce_auto_scheduler_test(N):
+    A = tvm.te.placeholder((N,), name="A")
+    k = tvm.te.reduce_axis((0, N), name="k")
+    B = tvm.te.compute((), lambda: tvm.te.sum(A[k], k), name="B")
+
+    return [A, B]
+
+
+@auto_scheduler.register_workload
+def zero_rank_compute_auto_scheduler_test(N):
+    A = tvm.te.placeholder((N,), name="A")
+    B = tvm.te.compute((), lambda: A[0], name="B")
+
+    return [A, B]
+
+
+@auto_scheduler.register_workload
 def conv2d_winograd_nhwc_auto_scheduler_test(
     N, H, W, CI, CO, kernel_size=3, stride=1, padding=0, dilation=1
 ):
diff --git a/tests/python/unittest/test_auto_scheduler_search_policy.py b/tests/python/unittest/test_auto_scheduler_search_policy.py
index 73ce0a1..c96dc63 100644
--- a/tests/python/unittest/test_auto_scheduler_search_policy.py
+++ b/tests/python/unittest/test_auto_scheduler_search_policy.py
@@ -25,8 +25,13 @@ import tempfile
 import tvm
 import tvm.testing
 from tvm import auto_scheduler
+from tvm.auto_scheduler.utils import get_const_tuple
 
-from test_auto_scheduler_common import matmul_auto_scheduler_test
+from test_auto_scheduler_common import (
+    matmul_auto_scheduler_test,
+    zero_rank_compute_auto_scheduler_test,
+    zero_rank_reduce_auto_scheduler_test,
+)
 import multiprocessing
 
 
@@ -41,21 +46,21 @@ class CustomMeasureCallback(auto_scheduler.measure.PythonBasedMeasureCallback):
 
 
 def search_common(
-    workload=matmul_auto_scheduler_test,
+    task=None,
     target="llvm",
     search_policy="sketch",
-    seed=0,
     runner="local",
     num_measure_trials=100,
     cost_model=auto_scheduler.RandomModel(),
     init_search_callbacks=None,
 ):
-    print("Test search policy '%s' for '%s'" % (search_policy, target))
+    if task is None:
+        task = auto_scheduler.SearchTask(
+            func=matmul_auto_scheduler_test, args=(64, 64, 64), target=target
+        )
+    target = task.target
 
-    random.seed(seed)
-    N = 128
-    target = tvm.target.Target(target)
-    task = auto_scheduler.SearchTask(func=workload, args=(N, N, N), target=target)
+    print("Test search policy '%s' for '%s'" % (search_policy, target))
 
     with tempfile.NamedTemporaryFile() as fp:
         log_file = fp.name
@@ -72,6 +77,7 @@ def search_common(
         else:
             raise ValueError("Invalid policy: " + search_policy)
 
+        # Tune
         tuning_options = auto_scheduler.TuningOptions(
             num_measure_trials=num_measure_trials,
             num_measures_per_round=2,
@@ -80,33 +86,47 @@ def search_common(
             measure_callbacks=[auto_scheduler.RecordToFile(log_file), CustomMeasureCallback()],
         )
         task.tune(tuning_options=tuning_options, search_policy=search_policy)
+
+        # Compile with the best schedule
         sch, args = task.apply_best(log_file)
+        mod = tvm.build(sch, args, target)
 
-        try:
-            mod = tvm.build(sch, args, target)
+        # Compile with naive schedule for correctness check
+        sch, args = task.compute_dag.apply_steps_from_state(task.compute_dag.init_state)
+        mod_ref = tvm.build(sch, args, "llvm")
 
-            ctx = tvm.context(str(target), 0)
-            dtype = task.compute_dag.tensors[0].dtype
-            a = tvm.nd.array(np.random.uniform(size=(N, N)).astype(dtype), ctx)
-            b = tvm.nd.array(np.random.uniform(size=(N, N)).astype(dtype), ctx)
-            c = tvm.nd.array(np.zeros((N, N), dtype=dtype), ctx)
-            mod(a, b, c)
-            tvm.testing.assert_allclose(c.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()), rtol=1e-5)
-        except Exception:
-            raise Exception("Error encountered with seed: %d" % (seed))
+        ctx = tvm.context(str(target), 0)
+        np_arrays = [np.random.uniform(size=get_const_tuple(x.shape)).astype(x.dtype) for x in args]
+
+        tvm_arrays = [tvm.nd.array(x, ctx) for x in np_arrays]
+        mod(*tvm_arrays)
+        actual = [x.asnumpy() for x in tvm_arrays]
+
+        tvm_arrays = [tvm.nd.array(x) for x in np_arrays]
+        mod_ref(*tvm_arrays)
+        expected = [x.asnumpy() for x in tvm_arrays]
+
+        for x, y in zip(actual, expected):
+            tvm.testing.assert_allclose(x, y, rtol=1e-5)
 
 
 @tvm.testing.requires_llvm
-def test_workload_registry_search_basic():
+def test_workload_registry_empty_policy():
     search_common(search_policy="empty", num_measure_trials=2)
 
+    N = 64
+    target = "llvm"
     search_common(
-        workload="matmul_auto_scheduler_test",
+        task=auto_scheduler.SearchTask(
+            func="matmul_auto_scheduler_test", args=(N, N, N), target=target
+        ),
         num_measure_trials=2,
         search_policy="empty",
     )
     search_common(
-        workload="matmul_auto_scheduler_test_rename_1",
+        task=auto_scheduler.SearchTask(
+            func="matmul_auto_scheduler_test_rename_1", args=(N, N, N), target=target
+        ),
         num_measure_trials=2,
         search_policy="empty",
     )
@@ -147,10 +167,27 @@ def test_sketch_search_policy_cuda_xgbmodel_rpc_runner():
     search_common(target="cuda", runner=measure_ctx.runner, cost_model=auto_scheduler.XGBModel())
 
 
+@tvm.testing.requires_llvm
+@tvm.testing.requires_cuda
+def test_sketch_search_policy_zero_rank():
+    measure_ctx = auto_scheduler.LocalRPCMeasureContext()
+    for target in ["llvm", "cuda"]:
+        task = auto_scheduler.SearchTask(
+            func=zero_rank_compute_auto_scheduler_test, args=(10,), target=target
+        )
+        search_common(task, runner=measure_ctx.runner)
+
+        task = auto_scheduler.SearchTask(
+            func=zero_rank_reduce_auto_scheduler_test, args=(10,), target=target
+        )
+        search_common(task, runner=measure_ctx.runner)
+
+
 if __name__ == "__main__":
-    test_workload_registry_search_basic()
+    test_workload_registry_empty_policy()
     test_sketch_search_policy_basic()
     test_sketch_search_policy_basic_spawn()
     test_sketch_search_policy_xgbmodel()
     test_sketch_search_policy_cuda_rpc_runner()
     test_sketch_search_policy_cuda_xgbmodel_rpc_runner()
+    test_sketch_search_policy_zero_rank()
diff --git a/tests/python/unittest/test_auto_scheduler_sketch_generation.py b/tests/python/unittest/test_auto_scheduler_sketch_generation.py
index 74d5729..ddff6dd 100644
--- a/tests/python/unittest/test_auto_scheduler_sketch_generation.py
+++ b/tests/python/unittest/test_auto_scheduler_sketch_generation.py
@@ -32,6 +32,7 @@ from test_auto_scheduler_common import (
     softmax_nm_auto_scheduler_test,
     softmax_abcd_auto_scheduler_test,
     conv2d_winograd_nhwc_auto_scheduler_test,
+    zero_rank_reduce_auto_scheduler_test,
 )
 
 
@@ -252,6 +253,12 @@ def test_cpu_conv2d_winograd_sketch():
     assert sketches[1] != sketches[2]
 
 
+def test_cpu_zero_rank_sketch():
+    sketches = generate_sketches(zero_rank_reduce_auto_scheduler_test, (128,), "llvm")
+    """ 2 rfactor sketches + 1 multi-level tiling sketches """
+    assert len(sketches) == 3
+
+
 @tvm.testing.requires_cuda
 def test_cuda_matmul_sketch():
     sketches = generate_sketches(matmul_auto_scheduler_test, (512, 512, 512), "cuda")
@@ -385,6 +392,13 @@ def test_cuda_conv2d_winograd_sketch():
     assert_is_not_tiled(sketches[0].stages[12])
 
 
+@tvm.testing.requires_cuda
+def test_cuda_zero_rank_sketch():
+    sketches = generate_sketches(zero_rank_reduce_auto_scheduler_test, (128,), "cuda")
+    """ 1 cross thread reuction sketch + 1 multi-level tiling sketch """
+    assert len(sketches) == 2
+
+
 if __name__ == "__main__":
     test_cpu_matmul_sketch()
     test_cpu_conv2d_bn_relu_sketch()
@@ -392,9 +406,11 @@ if __name__ == "__main__":
     test_cpu_min_sketch()
     test_cpu_softmax_sketch()
     test_cpu_conv2d_winograd_sketch()
+    test_cpu_zero_rank_sketch()
     test_cuda_matmul_sketch()
     test_cuda_conv2d_bn_relu_sketch()
     test_cuda_max_pool2d_sketch()
     test_cuda_min_sketch()
     test_cuda_softmax_sketch()
     test_cuda_conv2d_winograd_sketch()
+    test_cuda_zero_rank_sketch()

Reply via email to