This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4dd7f68 [TIR] Use PopenPool instead of multiprocessing.pool (#8492)
4dd7f68 is described below
commit 4dd7f6806f05bbc9c33d68d68493113534a34b12
Author: Yuanjing Shi <[email protected]>
AuthorDate: Thu Aug 12 13:35:04 2021 -0700
[TIR] Use PopenPool instead of multiprocessing.pool (#8492)
Co-authored-by: Wuwei Lin <[email protected]>
---
python/tvm/auto_scheduler/measure.py | 216 ++++++++++++---------
python/tvm/auto_scheduler/utils.py | 43 +---
python/tvm/autotvm/record.py | 4 +-
python/tvm/autotvm/utils.py | 4 +-
python/tvm/contrib/popen_pool.py | 9 +-
python/tvm/testing/__init__.py | 34 ++++
python/tvm/testing/_ffi_api.py | 21 ++
.../tvm/testing/auto_scheduler.py | 2 +-
python/tvm/{testing.py => testing/utils.py} | 3 -
.../unittest/test_auto_scheduler_compute_dag.py | 2 +-
.../unittest/test_auto_scheduler_cost_model.py | 2 +-
.../test_auto_scheduler_evolutionary_search.py | 2 +-
.../python/unittest/test_auto_scheduler_feature.py | 2 +-
.../unittest/test_auto_scheduler_layout_rewrite.py | 2 +-
.../unittest/test_auto_scheduler_loop_state.py | 2 +-
.../python/unittest/test_auto_scheduler_measure.py | 2 +-
.../unittest/test_auto_scheduler_search_policy.py | 2 +-
.../unittest/test_auto_scheduler_search_task.py | 2 +-
.../test_auto_scheduler_sketch_generation.py | 2 +-
.../unittest/test_auto_scheduler_task_scheduler.py | 2 +-
20 files changed, 205 insertions(+), 153 deletions(-)
diff --git a/python/tvm/auto_scheduler/measure.py
b/python/tvm/auto_scheduler/measure.py
index 8d76260..a202e83 100644
--- a/python/tvm/auto_scheduler/measure.py
+++ b/python/tvm/auto_scheduler/measure.py
@@ -44,6 +44,7 @@ from tvm.driver import build_module
from tvm.ir import transform
from tvm.autotvm.measure.measure_methods import set_cuda_target_arch
from tvm.contrib import tar, ndk
+from tvm.contrib.popen_pool import PopenWorker, PopenPoolExecutor, StatusKind
from tvm.target import Target
@@ -599,7 +600,7 @@ class MeasureErrorNo(object):
UNKNOWN_ERROR = 8 # Unknown error
-def _timed_func(inp_serialized, build_func, verbose):
+def _local_build_worker(inp_serialized, build_func, verbose):
tic = time.time()
inp = MeasureInput.deserialize(inp_serialized)
task = inp.task
@@ -664,15 +665,13 @@ def local_build_worker(args):
)
build_func = BuildFunc.build_func
- res = call_func_with_timeout(timeout, _timed_func, args=(inp, build_func,
verbose))
- if isinstance(res, TimeoutError):
- if verbose >= 1:
- print(".T", end="", flush=True) # Build timeout
- res = None, [], MeasureErrorNo.BUILD_TIMEOUT, None, timeout
- elif isinstance(res, Exception):
+ try:
+ res = _local_build_worker(inp, build_func, verbose)
+ # pylint: disable=broad-except
+ except Exception:
if verbose >= 1:
print(".E", end="", flush=True) # Build error
- res = None, [], MeasureErrorNo.COMPILE_HOST, str(res), timeout
+ res = None, [], MeasureErrorNo.COMPILE_HOST, make_traceback_info(),
timeout
return res
@@ -701,9 +700,8 @@ def local_builder_build(inputs, timeout, n_parallel,
build_func="default", verbo
res : List[BuildResult]
The build results of these MeasureInputs.
"""
- # This pool is not doing computationally intensive work, so we can use
threads
- pool = multiprocessing.pool.ThreadPool(n_parallel)
- tuple_res = pool.map(
+ executor = PopenPoolExecutor(n_parallel, timeout)
+ tuple_res = executor.map_with_error_catching(
local_build_worker,
[
(
@@ -715,13 +713,16 @@ def local_builder_build(inputs, timeout, n_parallel,
build_func="default", verbo
for i in inputs
],
)
- pool.terminate()
- pool.join()
- del pool
results = []
for res in tuple_res:
- results.append(BuildResult(*res))
+ if res.status == StatusKind.COMPLETE:
+ results.append(BuildResult(*res.value))
+ else:
+ assert res.status == StatusKind.TIMEOUT
+ if verbose >= 1:
+ print(".T", end="", flush=True) # Build timeout
+ results.append(BuildResult(None, [], MeasureErrorNo.BUILD_TIMEOUT,
None, timeout))
return results
@@ -817,9 +818,58 @@ def prepare_input_map(args):
return tensor_input_map
+def prepare_runner_args(inp, build_res):
+ """This function prepares the pre-defined arguments in
`TASK_INPUT_BUFFER_TABLE` for local/rpc
+ runner in the main process
+
+ Parameters
+ ----------
+ inp : MeasureInput
+ Measure input to be measured.
+
+ build_res : BuildResult
+ Build result to be measured.
+
+ Returns
+ -------
+ List[Optional[numpy.ndarray]] :
+ List of arguments for running the program. If the argument does not
have a pre-defined input
+ buffer, None is added to the list as a placeholder.
+
+ """
+ # pylint: disable=import-outside-toplevel
+ from .search_task import get_task_input_buffer # lazily import to avoid
recursive dependency
+
+ task_input_names = inp.task.task_input_names
+ tensor_input_map = prepare_input_map(build_res.args)
+ if not task_input_names:
+ tensor_input_map = {}
+ args = []
+ task_inputs_count = 0
+ for arg in build_res.args:
+ if arg in tensor_input_map:
+ tensor_name = tensor_input_map[arg]
+ if tensor_name in task_input_names:
+ task_input_buffer =
get_task_input_buffer(inp.task.workload_key, tensor_name)
+ # convert tvm.NDArray to picklable numpy.ndarray
+ args.append(task_input_buffer.numpy())
+ task_inputs_count += 1
+ else:
+ raise ValueError(
+ "%s not found in task_inputs, " % (tensor_name)
+ + "should provide with `SearchTask(...,
task_inputs={...})`"
+ )
+ else:
+ args.append(None)
+ if task_inputs_count != len(task_input_names):
+ raise RuntimeError("task_inputs not fully matched, check if there's
any unexpected error")
+ return args
+
+
def _timed_eval_func(
inp_serialized,
build_res,
+ args,
number,
repeat,
min_repeat_ms,
@@ -827,11 +877,7 @@ def _timed_eval_func(
enable_cpu_cache_flush,
verbose,
):
- # pylint: disable=import-outside-toplevel
- from .search_task import get_task_input_buffer # lazily import to avoid
recursive dependency
-
inp = MeasureInput.deserialize(inp_serialized)
- task_input_names = inp.task.task_input_names
tic = time.time()
error_no = 0
error_msg = None
@@ -862,33 +908,18 @@ def _timed_eval_func(
try:
random_fill =
tvm.get_global_func("tvm.contrib.random.random_fill", True)
assert random_fill, "Please make sure USE_RANDOM is ON in the
config.cmake"
-
- tensor_input_map = prepare_input_map(build_res.args) if
task_input_names else {}
- args = []
- task_inputs_count = 0
- for arg in build_res.args:
- if arg in tensor_input_map:
- tensor_name = tensor_input_map[arg]
- if tensor_name in task_input_names:
- args.append(
- ndarray.array(
- get_task_input_buffer(inp.task.workload_key,
tensor_name), dev
- )
- )
- task_inputs_count += 1
- else:
- raise ValueError(
- "%s not found in task_inputs, " % (tensor_name)
- + "should provide with `SearchTask(...,
task_inputs={...})`"
- )
- else:
- empty_array = ndarray.empty(get_const_tuple(arg.shape),
arg.dtype, dev)
+ assert len(args) == len(build_res.args)
+ # pylint: disable=consider-using-enumerate
+ for idx in range(len(args)):
+ if args[idx] is None:
+ build_res_arg = build_res.args[idx]
+ empty_array = ndarray.empty(
+ get_const_tuple(build_res_arg.shape),
build_res_arg.dtype, dev
+ )
random_fill(empty_array)
- args.append(empty_array)
- if task_inputs_count != len(task_input_names):
- raise RuntimeError(
- "task_inputs not fully matched, check if there's any
unexpected error"
- )
+ args[idx] = empty_array
+ else:
+ args[idx] = ndarray.array(args[idx], dev)
dev.sync()
costs = time_f(*args).results
# pylint: disable=broad-except
@@ -968,6 +999,7 @@ def local_run(
measure_results = []
assert len(inputs) == len(build_results), "Measure input size should be
equal to build results"
+ worker = PopenWorker()
for inp, build_res in zip(inputs, build_results):
if build_res.error_no != 0:
res = (
@@ -978,12 +1010,15 @@ def local_run(
time.time(),
)
else:
+ args = prepare_runner_args(inp, build_res)
res = call_func_with_timeout(
+ worker,
timeout,
_timed_eval_func,
args=(
inp.serialize(),
build_res,
+ args,
number,
repeat,
min_repeat_ms,
@@ -991,7 +1026,6 @@ def local_run(
enable_cpu_cache_flush,
verbose,
),
- add_thread_wrapper=True,
)
if isinstance(res, TimeoutError):
if verbose >= 1:
@@ -1022,9 +1056,10 @@ def local_run(
return measure_results
-def _timed_rpc_run(
+def _rpc_run(
inp_serialized,
build_res,
+ args,
key,
host,
port,
@@ -1037,11 +1072,7 @@ def _timed_rpc_run(
enable_cpu_cache_flush,
verbose,
):
- # pylint: disable=import-outside-toplevel
- from .search_task import get_task_input_buffer # lazily import to avoid
recursive dependency
-
inp = MeasureInput.deserialize(inp_serialized)
- task_input_names = inp.task.task_input_names
tic = time.time()
error_no = 0
error_msg = None
@@ -1080,32 +1111,18 @@ def _timed_rpc_run(
random_fill
), "Please make sure USE_RANDOM is ON in the config.cmake on the
remote devices"
- tensor_input_map = prepare_input_map(build_res.args) if
task_input_names else {}
- args = []
- task_inputs_count = 0
- for arg in build_res.args:
- if arg in tensor_input_map:
- tensor_name = tensor_input_map[arg]
- if tensor_name in task_input_names:
- args.append(
- ndarray.array(
- get_task_input_buffer(inp.task.workload_key,
tensor_name), dev
- )
- )
- task_inputs_count += 1
- else:
- raise ValueError(
- "%s not found in task_inputs, " % (tensor_name)
- + "should provide with `SearchTask(...,
task_inputs={...})`"
- )
- else:
- empty_array = ndarray.empty(get_const_tuple(arg.shape),
arg.dtype, dev)
+ assert len(args) == len(build_res.args)
+ # pylint: disable=consider-using-enumerate
+ for idx in range(len(args)):
+ if args[idx] is None:
+ build_res_arg = build_res.args[idx]
+ empty_array = ndarray.empty(
+ get_const_tuple(build_res_arg.shape),
build_res_arg.dtype, dev
+ )
random_fill(empty_array)
- args.append(empty_array)
- if task_inputs_count != len(task_input_names):
- logger.warning(
- "task_inputs not fully matched, check if there's any
unexpected error"
- )
+ args[idx] = empty_array
+ else:
+ args[idx] = ndarray.array(args[idx], dev)
dev.sync()
# First run for check that the kernel is correct
@@ -1152,7 +1169,7 @@ def _rpc_run_worker(args):
res : MeasureResult
The measure result of this Runner thread.
"""
- _, build_res, _, _, _, _, timeout, _, _, _, _, _, verbose = args
+ _, build_res, _, _, _, _, _, timeout, _, _, _, _, _, verbose = args
if build_res.error_no != MeasureErrorNo.NO_ERROR:
return (
(MAX_FLOAT,),
@@ -1162,24 +1179,16 @@ def _rpc_run_worker(args):
time.time(),
)
- res = call_func_with_timeout(timeout, _timed_rpc_run, args=args)
- if isinstance(res, TimeoutError):
- if verbose >= 1:
- print("*T", end="") # Run timeout
- res = (
- (MAX_FLOAT,),
- MeasureErrorNo.RUN_TIMEOUT,
- None,
- build_res.time_cost + timeout,
- time.time(),
- )
- elif isinstance(res, Exception):
+ try:
+ res = _rpc_run(*args)
+ # pylint: disable=broad-except
+ except Exception:
if verbose >= 1:
print("*E", end="") # Run error
res = (
(MAX_FLOAT,),
MeasureErrorNo.RUNTIME_DEVICE,
- str(res),
+ make_traceback_info(),
build_res.time_cost + timeout,
time.time(),
)
@@ -1259,13 +1268,14 @@ def rpc_runner_run(
"""
assert len(inputs) == len(build_results), "Measure input size should be
equal to build results"
# This pool is not doing computationally intensive work, so we can use
threads
- pool = multiprocessing.pool.ThreadPool(n_parallel)
- tuple_res = pool.map(
+ executor = PopenPoolExecutor(n_parallel)
+ tuple_res = executor.map_with_error_catching(
_rpc_run_worker,
[
(
inp.serialize(),
build_res,
+ prepare_runner_args(inp, build_res),
key,
host,
port,
@@ -1281,13 +1291,25 @@ def rpc_runner_run(
for inp, build_res in zip(inputs, build_results)
],
)
- pool.terminate()
- pool.join()
- del pool
results = []
- for res in tuple_res:
- results.append(MeasureResult(*res))
+ for i, res in enumerate(tuple_res):
+ if res.status == StatusKind.COMPLETE:
+ results.append(MeasureResult(*res.value))
+ else:
+ assert res.status == StatusKind.TIMEOUT
+ if verbose >= 1:
+ print("*T", end="") # Run timeout
+ build_res = build_results[i]
+ results.append(
+ MeasureResult(
+ (MAX_FLOAT,),
+ MeasureErrorNo.RUN_TIMEOUT,
+ None,
+ build_res.time_cost + timeout,
+ time.time(),
+ )
+ )
if verbose >= 1:
print("")
diff --git a/python/tvm/auto_scheduler/utils.py
b/python/tvm/auto_scheduler/utils.py
index 1c03491..9919bcb 100644
--- a/python/tvm/auto_scheduler/utils.py
+++ b/python/tvm/auto_scheduler/utils.py
@@ -20,9 +20,6 @@
from typing import Hashable
import json
-import multiprocessing
-import multiprocessing.pool
-import queue
import signal
import threading
import traceback
@@ -289,41 +286,15 @@ def call_func_with_thread(func, args, kwargs):
return res[0]
-def _func_wrapper(que, func, args, kwargs, add_thread_wrapper):
- """Call function and return the result over the queue."""
- try:
- if add_thread_wrapper:
- # Add a new layer of threading to avoid the conflict between
- # python's multiprocessing and tvm's thread pool.
- res = call_func_with_thread(func, args, kwargs)
- else:
- res = func(*args, **kwargs)
- que.put(res)
- except Exception: # pylint: disable=broad-except
- que.put(Exception(make_traceback_info()))
-
-
-def call_func_with_timeout(timeout, func, args=(), kwargs=None,
add_thread_wrapper=False):
+def call_func_with_timeout(
+ worker, timeout, func, args=(), kwargs=None
+): # pylint: disable=unused-argument
"""Call a function with timeout"""
- que = multiprocessing.Queue(2)
- process = multiprocessing.Process(
- target=_func_wrapper, args=(que, func, args, kwargs or {},
add_thread_wrapper)
- )
- process.start()
-
+ worker.send(func, args, kwargs, timeout)
try:
- res = que.get(timeout=timeout)
- except queue.Empty:
- res = TimeoutError()
-
- # clean queue and process
- kill_child_processes(process.pid)
- process.terminate()
- process.join()
- que.close()
- que.join_thread()
- del process
- del que
+ res = worker.recv()
+ except Exception: # pylint: disable=broad-except
+ res = Exception(make_traceback_info())
return res
diff --git a/python/tvm/autotvm/record.py b/python/tvm/autotvm/record.py
index 4f11aea..8145563 100644
--- a/python/tvm/autotvm/record.py
+++ b/python/tvm/autotvm/record.py
@@ -21,7 +21,6 @@
import argparse
import base64
import logging
-import multiprocessing
import pickle
import json
import time
@@ -32,6 +31,7 @@ import numpy as np
from .. import build, lower
from ..target import Target
+from ..contrib import popen_pool
from .. import __version__
from . import task
from .task import ConfigEntity, ApplyHistoryBest
@@ -230,7 +230,7 @@ def split_workload(in_file, clean=True):
lines = list(open(in_file).readlines())
logger.info("start converting...")
- pool = multiprocessing.Pool()
+ pool = popen_pool.PopenPoolExecutor()
lines = [rec for rec in pool.map(decode, lines) if rec is not None]
logger.info("map done %.2f", time.time() - tic)
diff --git a/python/tvm/autotvm/utils.py b/python/tvm/autotvm/utils.py
index fa1dcfd..ec3f18d 100644
--- a/python/tvm/autotvm/utils.py
+++ b/python/tvm/autotvm/utils.py
@@ -17,7 +17,6 @@
# pylint: disable=invalid-name
"""Utilities"""
import logging
-import multiprocessing
import time
from random import randrange
@@ -25,6 +24,7 @@ from random import randrange
import numpy as np
import tvm.arith
from tvm.tir import expr
+from tvm.contrib.popen_pool import PopenPoolExecutor
logger = logging.getLogger("autotvm")
@@ -111,7 +111,7 @@ def pool_map(func, args, batch_size, verbose=False,
pool=None):
ret = None
tic = time.time()
- local_pool = pool or multiprocessing.Pool()
+ local_pool = pool or PopenPoolExecutor()
if verbose:
logger.info("mapping begin")
for i in range(0, len(args), batch_size):
diff --git a/python/tvm/contrib/popen_pool.py b/python/tvm/contrib/popen_pool.py
index 2f55203..68c21ef 100644
--- a/python/tvm/contrib/popen_pool.py
+++ b/python/tvm/contrib/popen_pool.py
@@ -269,9 +269,16 @@ class PopenPoolExecutor:
timeout : float
Timeout value for each function submit.
+ Note
+ ----
+ If max_workers is None then the number returned by
+ os.cpu_count() is used. This method aligns with the
+ behavior of multiprocessing.Pool().
"""
- def __init__(self, max_workers, timeout=None):
+ def __init__(self, max_workers=None, timeout=None):
+ if max_workers is None:
+ max_workers = os.cpu_count()
# Use an internal thread pool to send to popen workers
self._threadpool =
concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
self._timeout = timeout
diff --git a/python/tvm/testing/__init__.py b/python/tvm/testing/__init__.py
new file mode 100644
index 0000000..bd1ada4
--- /dev/null
+++ b/python/tvm/testing/__init__.py
@@ -0,0 +1,34 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin, wildcard-import
+"""Utility Python functions for TVM testing"""
+from .utils import assert_allclose, assert_prim_expr_equal,
check_bool_expr_is_true
+from .utils import check_int_constraints_trans_consistency,
check_numerical_grads
+from .utils import device_enabled, enabled_targets, exclude_targets
+from .utils import fixture, parameter, parameters, parametrize_targets,
uses_gpu
+from .utils import known_failing_targets, requires_cuda, requires_cudagraph
+from .utils import requires_gpu, requires_llvm, requires_rocm, requires_rpc
+from .utils import requires_tensorcore, requires_metal, requires_micro,
requires_opencl
+from .utils import _auto_parametrize_target, _count_num_fixture_uses
+from .utils import _remove_global_fixture_definitions,
_parametrize_correlated_parameters
+from .utils import _pytest_target_params, identity_after, terminate_self
+
+from ._ffi_api import nop, echo, device_test, run_check_signal,
object_use_count
+from ._ffi_api import test_wrap_callback, test_raise_error_callback,
test_check_eq_callback
+from ._ffi_api import ErrorTest, FrontendTestModule
+
+from . import auto_scheduler
diff --git a/python/tvm/testing/_ffi_api.py b/python/tvm/testing/_ffi_api.py
new file mode 100644
index 0000000..56a7722
--- /dev/null
+++ b/python/tvm/testing/_ffi_api.py
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""FFI APIs for tvm.testing"""
+import tvm._ffi
+
+
+tvm._ffi._init_api("testing", __name__)
diff --git a/tests/python/unittest/test_auto_scheduler_common.py
b/python/tvm/testing/auto_scheduler.py
similarity index 99%
rename from tests/python/unittest/test_auto_scheduler_common.py
rename to python/tvm/testing/auto_scheduler.py
index 4890268..bc335c8 100644
--- a/tests/python/unittest/test_auto_scheduler_common.py
+++ b/python/tvm/testing/auto_scheduler.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
+# pylint: disable=invalid-name, missing-function-docstring
"""Common functions for auto_scheduler test cases"""
import tvm
from tvm import auto_scheduler, te, topi
diff --git a/python/tvm/testing.py b/python/tvm/testing/utils.py
similarity index 99%
rename from python/tvm/testing.py
rename to python/tvm/testing/utils.py
index 9515189..71ab077 100644
--- a/python/tvm/testing.py
+++ b/python/tvm/testing/utils.py
@@ -1376,6 +1376,3 @@ def identity_after(x, sleep):
def terminate_self():
"""Testing function to terminate the process."""
sys.exit(-1)
-
-
-tvm._ffi._init_api("testing", __name__)
diff --git a/tests/python/unittest/test_auto_scheduler_compute_dag.py
b/tests/python/unittest/test_auto_scheduler_compute_dag.py
index e394115..81ee5ca 100644
--- a/tests/python/unittest/test_auto_scheduler_compute_dag.py
+++ b/tests/python/unittest/test_auto_scheduler_compute_dag.py
@@ -23,7 +23,7 @@ import tvm
from tvm import topi
from tvm import auto_scheduler, te
-from test_auto_scheduler_common import (
+from tvm.testing.auto_scheduler import (
get_tiled_matmul,
invalid_compute_definition,
matmul_auto_scheduler_test,
diff --git a/tests/python/unittest/test_auto_scheduler_cost_model.py
b/tests/python/unittest/test_auto_scheduler_cost_model.py
index 0b34615..50e3ceb 100644
--- a/tests/python/unittest/test_auto_scheduler_cost_model.py
+++ b/tests/python/unittest/test_auto_scheduler_cost_model.py
@@ -24,7 +24,7 @@ import numpy as np
import tvm
from tvm import auto_scheduler
-from test_auto_scheduler_common import matmul_auto_scheduler_test
+from tvm.testing.auto_scheduler import matmul_auto_scheduler_test
def get_sample_records(number):
diff --git a/tests/python/unittest/test_auto_scheduler_evolutionary_search.py
b/tests/python/unittest/test_auto_scheduler_evolutionary_search.py
index e28219d..b5c99c0 100644
--- a/tests/python/unittest/test_auto_scheduler_evolutionary_search.py
+++ b/tests/python/unittest/test_auto_scheduler_evolutionary_search.py
@@ -18,7 +18,7 @@
import tvm
import pytest
-from test_auto_scheduler_common import matmul_auto_scheduler_test
+from tvm.testing.auto_scheduler import matmul_auto_scheduler_test
from tvm import auto_scheduler, te
from tvm.auto_scheduler.cost_model.cost_model import PythonBasedModel
diff --git a/tests/python/unittest/test_auto_scheduler_feature.py
b/tests/python/unittest/test_auto_scheduler_feature.py
index 82cfb1d..96090e3 100644
--- a/tests/python/unittest/test_auto_scheduler_feature.py
+++ b/tests/python/unittest/test_auto_scheduler_feature.py
@@ -23,7 +23,7 @@ import tempfile
import tvm
from tvm import te, auto_scheduler
-from test_auto_scheduler_common import matmul_auto_scheduler_test
+from tvm.testing.auto_scheduler import matmul_auto_scheduler_test
def fequal(a, b):
diff --git a/tests/python/unittest/test_auto_scheduler_layout_rewrite.py
b/tests/python/unittest/test_auto_scheduler_layout_rewrite.py
index c929196..39673fa 100644
--- a/tests/python/unittest/test_auto_scheduler_layout_rewrite.py
+++ b/tests/python/unittest/test_auto_scheduler_layout_rewrite.py
@@ -26,7 +26,7 @@ import tvm.testing
from tvm import topi
from tvm import auto_scheduler, te
-from test_auto_scheduler_common import get_tiled_matmul,
matmul_auto_scheduler_test
+from tvm.testing.auto_scheduler import get_tiled_matmul,
matmul_auto_scheduler_test
def test_apply_steps_with_layout_rewrite():
diff --git a/tests/python/unittest/test_auto_scheduler_loop_state.py
b/tests/python/unittest/test_auto_scheduler_loop_state.py
index 44ed1fc..0965ed9 100644
--- a/tests/python/unittest/test_auto_scheduler_loop_state.py
+++ b/tests/python/unittest/test_auto_scheduler_loop_state.py
@@ -23,7 +23,7 @@ import tvm
from tvm import auto_scheduler, te
from tvm import topi
-from test_auto_scheduler_common import (
+from tvm.testing.auto_scheduler import (
matmul_auto_scheduler_test,
conv2d_nchw_bn_relu_auto_scheduler_test,
)
diff --git a/tests/python/unittest/test_auto_scheduler_measure.py
b/tests/python/unittest/test_auto_scheduler_measure.py
index 375f816..9eae3dd 100644
--- a/tests/python/unittest/test_auto_scheduler_measure.py
+++ b/tests/python/unittest/test_auto_scheduler_measure.py
@@ -26,7 +26,7 @@ from tvm import te, auto_scheduler
import tempfile
import tvm.testing
import pickle
-from test_auto_scheduler_common import matmul_auto_scheduler_test
+from tvm.testing.auto_scheduler import matmul_auto_scheduler_test
from tvm.auto_scheduler import workload_registry
diff --git a/tests/python/unittest/test_auto_scheduler_search_policy.py
b/tests/python/unittest/test_auto_scheduler_search_policy.py
index d114ce4..a9f6596 100644
--- a/tests/python/unittest/test_auto_scheduler_search_policy.py
+++ b/tests/python/unittest/test_auto_scheduler_search_policy.py
@@ -27,7 +27,7 @@ import tvm.testing
from tvm import auto_scheduler
from tvm.auto_scheduler.utils import get_const_tuple
-from test_auto_scheduler_common import (
+from tvm.testing.auto_scheduler import (
matmul_auto_scheduler_test,
zero_rank_compute_auto_scheduler_test,
zero_rank_reduce_auto_scheduler_test,
diff --git a/tests/python/unittest/test_auto_scheduler_search_task.py
b/tests/python/unittest/test_auto_scheduler_search_task.py
index cd47f1e..f23b02c 100644
--- a/tests/python/unittest/test_auto_scheduler_search_task.py
+++ b/tests/python/unittest/test_auto_scheduler_search_task.py
@@ -24,7 +24,7 @@ import tvm
import tvm.testing
from tvm import auto_scheduler
from tvm.auto_scheduler.utils import get_const_tuple
-from test_auto_scheduler_common import (
+from tvm.testing.auto_scheduler import (
matmul_auto_scheduler_test,
zero_rank_compute_auto_scheduler_test,
zero_rank_reduce_auto_scheduler_test,
diff --git a/tests/python/unittest/test_auto_scheduler_sketch_generation.py
b/tests/python/unittest/test_auto_scheduler_sketch_generation.py
index 4092ae0..6d2f870 100644
--- a/tests/python/unittest/test_auto_scheduler_sketch_generation.py
+++ b/tests/python/unittest/test_auto_scheduler_sketch_generation.py
@@ -27,7 +27,7 @@ from tvm import te, auto_scheduler
from tvm.auto_scheduler import _ffi_api
from tvm.auto_scheduler.loop_state import Stage
-from test_auto_scheduler_common import (
+from tvm.testing.auto_scheduler import (
matmul_auto_scheduler_test,
double_matmul_auto_scheduler_test,
conv2d_nchw_bn_relu_auto_scheduler_test,
diff --git a/tests/python/unittest/test_auto_scheduler_task_scheduler.py
b/tests/python/unittest/test_auto_scheduler_task_scheduler.py
index bbe29b1..a3f3569 100644
--- a/tests/python/unittest/test_auto_scheduler_task_scheduler.py
+++ b/tests/python/unittest/test_auto_scheduler_task_scheduler.py
@@ -25,7 +25,7 @@ import tvm
import tvm.testing
from tvm import auto_scheduler
-from test_auto_scheduler_common import matmul_auto_scheduler_test
+from tvm.testing.auto_scheduler import matmul_auto_scheduler_test
@tvm.testing.requires_llvm