This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4dd7f68  [TIR] Use PopenPool instead of multiprocessing.pool (#8492)
4dd7f68 is described below

commit 4dd7f6806f05bbc9c33d68d68493113534a34b12
Author: Yuanjing Shi <[email protected]>
AuthorDate: Thu Aug 12 13:35:04 2021 -0700

    [TIR] Use PopenPool instead of multiprocessing.pool (#8492)
    
    Co-authored-by: Wuwei Lin <[email protected]>
---
 python/tvm/auto_scheduler/measure.py               | 216 ++++++++++++---------
 python/tvm/auto_scheduler/utils.py                 |  43 +---
 python/tvm/autotvm/record.py                       |   4 +-
 python/tvm/autotvm/utils.py                        |   4 +-
 python/tvm/contrib/popen_pool.py                   |   9 +-
 python/tvm/testing/__init__.py                     |  34 ++++
 python/tvm/testing/_ffi_api.py                     |  21 ++
 .../tvm/testing/auto_scheduler.py                  |   2 +-
 python/tvm/{testing.py => testing/utils.py}        |   3 -
 .../unittest/test_auto_scheduler_compute_dag.py    |   2 +-
 .../unittest/test_auto_scheduler_cost_model.py     |   2 +-
 .../test_auto_scheduler_evolutionary_search.py     |   2 +-
 .../python/unittest/test_auto_scheduler_feature.py |   2 +-
 .../unittest/test_auto_scheduler_layout_rewrite.py |   2 +-
 .../unittest/test_auto_scheduler_loop_state.py     |   2 +-
 .../python/unittest/test_auto_scheduler_measure.py |   2 +-
 .../unittest/test_auto_scheduler_search_policy.py  |   2 +-
 .../unittest/test_auto_scheduler_search_task.py    |   2 +-
 .../test_auto_scheduler_sketch_generation.py       |   2 +-
 .../unittest/test_auto_scheduler_task_scheduler.py |   2 +-
 20 files changed, 205 insertions(+), 153 deletions(-)

diff --git a/python/tvm/auto_scheduler/measure.py 
b/python/tvm/auto_scheduler/measure.py
index 8d76260..a202e83 100644
--- a/python/tvm/auto_scheduler/measure.py
+++ b/python/tvm/auto_scheduler/measure.py
@@ -44,6 +44,7 @@ from tvm.driver import build_module
 from tvm.ir import transform
 from tvm.autotvm.measure.measure_methods import set_cuda_target_arch
 from tvm.contrib import tar, ndk
+from tvm.contrib.popen_pool import PopenWorker, PopenPoolExecutor, StatusKind
 from tvm.target import Target
 
 
@@ -599,7 +600,7 @@ class MeasureErrorNo(object):
     UNKNOWN_ERROR = 8  # Unknown error
 
 
-def _timed_func(inp_serialized, build_func, verbose):
+def _local_build_worker(inp_serialized, build_func, verbose):
     tic = time.time()
     inp = MeasureInput.deserialize(inp_serialized)
     task = inp.task
@@ -664,15 +665,13 @@ def local_build_worker(args):
     )
     build_func = BuildFunc.build_func
 
-    res = call_func_with_timeout(timeout, _timed_func, args=(inp, build_func, 
verbose))
-    if isinstance(res, TimeoutError):
-        if verbose >= 1:
-            print(".T", end="", flush=True)  # Build timeout
-        res = None, [], MeasureErrorNo.BUILD_TIMEOUT, None, timeout
-    elif isinstance(res, Exception):
+    try:
+        res = _local_build_worker(inp, build_func, verbose)
+    # pylint: disable=broad-except
+    except Exception:
         if verbose >= 1:
             print(".E", end="", flush=True)  # Build error
-        res = None, [], MeasureErrorNo.COMPILE_HOST, str(res), timeout
+        res = None, [], MeasureErrorNo.COMPILE_HOST, make_traceback_info(), 
timeout
 
     return res
 
@@ -701,9 +700,8 @@ def local_builder_build(inputs, timeout, n_parallel, 
build_func="default", verbo
     res : List[BuildResult]
         The build results of these MeasureInputs.
     """
-    # This pool is not doing computationally intensive work, so we can use 
threads
-    pool = multiprocessing.pool.ThreadPool(n_parallel)
-    tuple_res = pool.map(
+    executor = PopenPoolExecutor(n_parallel, timeout)
+    tuple_res = executor.map_with_error_catching(
         local_build_worker,
         [
             (
@@ -715,13 +713,16 @@ def local_builder_build(inputs, timeout, n_parallel, 
build_func="default", verbo
             for i in inputs
         ],
     )
-    pool.terminate()
-    pool.join()
-    del pool
 
     results = []
     for res in tuple_res:
-        results.append(BuildResult(*res))
+        if res.status == StatusKind.COMPLETE:
+            results.append(BuildResult(*res.value))
+        else:
+            assert res.status == StatusKind.TIMEOUT
+            if verbose >= 1:
+                print(".T", end="", flush=True)  # Build timeout
+            results.append(BuildResult(None, [], MeasureErrorNo.BUILD_TIMEOUT, 
None, timeout))
 
     return results
 
@@ -817,9 +818,58 @@ def prepare_input_map(args):
     return tensor_input_map
 
 
+def prepare_runner_args(inp, build_res):
+    """This function prepares the pre-defined arguments in 
`TASK_INPUT_BUFFER_TABLE` for local/rpc
+    runner in main process
+
+    Parameters
+    ----------
+    inp : MeasureInput
+        Measure input to be measured.
+
+    build_res : BuildResult
+        Build result to be measured.
+
+    Returns
+    -------
+    List[Optional[numpy.ndarray]] :
+        List of arguments for running the program. If the argument does not 
have a pre-defined input
+        buffer, None is added to the list as a placeholder.
+
+    """
+    # pylint: disable=import-outside-toplevel
+    from .search_task import get_task_input_buffer  # lazily import to avoid 
recursive dependency
+
+    task_input_names = inp.task.task_input_names
+    tensor_input_map = prepare_input_map(build_res.args)
+    if not task_input_names:
+        tensor_input_map = {}
+    args = []
+    task_inputs_count = 0
+    for arg in build_res.args:
+        if arg in tensor_input_map:
+            tensor_name = tensor_input_map[arg]
+            if tensor_name in task_input_names:
+                task_input_buffer = 
get_task_input_buffer(inp.task.workload_key, tensor_name)
+                # convert tvm.NDArray to picklable numpy.ndarray
+                args.append(task_input_buffer.numpy())
+                task_inputs_count += 1
+            else:
+                raise ValueError(
+                    "%s not found in task_inputs, " % (tensor_name)
+                    + "should provide with `SearchTask(..., 
task_inputs={...})`"
+                )
+        else:
+            args.append(None)
+    if task_inputs_count != len(task_input_names):
+        raise RuntimeError("task_inputs not fully matched, check if there's 
any unexpected error")
+    return args
+
+
 def _timed_eval_func(
     inp_serialized,
     build_res,
+    args,
     number,
     repeat,
     min_repeat_ms,
@@ -827,11 +877,7 @@ def _timed_eval_func(
     enable_cpu_cache_flush,
     verbose,
 ):
-    # pylint: disable=import-outside-toplevel
-    from .search_task import get_task_input_buffer  # lazily import to avoid 
recursive dependency
-
     inp = MeasureInput.deserialize(inp_serialized)
-    task_input_names = inp.task.task_input_names
     tic = time.time()
     error_no = 0
     error_msg = None
@@ -862,33 +908,18 @@ def _timed_eval_func(
         try:
             random_fill = 
tvm.get_global_func("tvm.contrib.random.random_fill", True)
             assert random_fill, "Please make sure USE_RANDOM is ON in the 
config.cmake"
-
-            tensor_input_map = prepare_input_map(build_res.args) if 
task_input_names else {}
-            args = []
-            task_inputs_count = 0
-            for arg in build_res.args:
-                if arg in tensor_input_map:
-                    tensor_name = tensor_input_map[arg]
-                    if tensor_name in task_input_names:
-                        args.append(
-                            ndarray.array(
-                                get_task_input_buffer(inp.task.workload_key, 
tensor_name), dev
-                            )
-                        )
-                        task_inputs_count += 1
-                    else:
-                        raise ValueError(
-                            "%s not found in task_inputs, " % (tensor_name)
-                            + "should provide with `SearchTask(..., 
task_inputs={...})`"
-                        )
-                else:
-                    empty_array = ndarray.empty(get_const_tuple(arg.shape), 
arg.dtype, dev)
+            assert len(args) == len(build_res.args)
+            # pylint: disable=consider-using-enumerate
+            for idx in range(len(args)):
+                if args[idx] is None:
+                    build_res_arg = build_res.args[idx]
+                    empty_array = ndarray.empty(
+                        get_const_tuple(build_res_arg.shape), 
build_res_arg.dtype, dev
+                    )
                     random_fill(empty_array)
-                    args.append(empty_array)
-            if task_inputs_count != len(task_input_names):
-                raise RuntimeError(
-                    "task_inputs not fully matched, check if there's any 
unexpected error"
-                )
+                    args[idx] = empty_array
+                else:
+                    args[idx] = ndarray.array(args[idx], dev)
             dev.sync()
             costs = time_f(*args).results
         # pylint: disable=broad-except
@@ -968,6 +999,7 @@ def local_run(
 
     measure_results = []
     assert len(inputs) == len(build_results), "Measure input size should be 
equal to build results"
+    worker = PopenWorker()
     for inp, build_res in zip(inputs, build_results):
         if build_res.error_no != 0:
             res = (
@@ -978,12 +1010,15 @@ def local_run(
                 time.time(),
             )
         else:
+            args = prepare_runner_args(inp, build_res)
             res = call_func_with_timeout(
+                worker,
                 timeout,
                 _timed_eval_func,
                 args=(
                     inp.serialize(),
                     build_res,
+                    args,
                     number,
                     repeat,
                     min_repeat_ms,
@@ -991,7 +1026,6 @@ def local_run(
                     enable_cpu_cache_flush,
                     verbose,
                 ),
-                add_thread_wrapper=True,
             )
             if isinstance(res, TimeoutError):
                 if verbose >= 1:
@@ -1022,9 +1056,10 @@ def local_run(
     return measure_results
 
 
-def _timed_rpc_run(
+def _rpc_run(
     inp_serialized,
     build_res,
+    args,
     key,
     host,
     port,
@@ -1037,11 +1072,7 @@ def _timed_rpc_run(
     enable_cpu_cache_flush,
     verbose,
 ):
-    # pylint: disable=import-outside-toplevel
-    from .search_task import get_task_input_buffer  # lazily import to avoid 
recursive dependency
-
     inp = MeasureInput.deserialize(inp_serialized)
-    task_input_names = inp.task.task_input_names
     tic = time.time()
     error_no = 0
     error_msg = None
@@ -1080,32 +1111,18 @@ def _timed_rpc_run(
                 random_fill
             ), "Please make sure USE_RANDOM is ON in the config.cmake on the 
remote devices"
 
-            tensor_input_map = prepare_input_map(build_res.args) if 
task_input_names else {}
-            args = []
-            task_inputs_count = 0
-            for arg in build_res.args:
-                if arg in tensor_input_map:
-                    tensor_name = tensor_input_map[arg]
-                    if tensor_name in task_input_names:
-                        args.append(
-                            ndarray.array(
-                                get_task_input_buffer(inp.task.workload_key, 
tensor_name), dev
-                            )
-                        )
-                        task_inputs_count += 1
-                    else:
-                        raise ValueError(
-                            "%s not found in task_inputs, " % (tensor_name)
-                            + "should provide with `SearchTask(..., 
task_inputs={...})`"
-                        )
-                else:
-                    empty_array = ndarray.empty(get_const_tuple(arg.shape), 
arg.dtype, dev)
+            assert len(args) == len(build_res.args)
+            # pylint: disable=consider-using-enumerate
+            for idx in range(len(args)):
+                if args[idx] is None:
+                    build_res_arg = build_res.args[idx]
+                    empty_array = ndarray.empty(
+                        get_const_tuple(build_res_arg.shape), 
build_res_arg.dtype, dev
+                    )
                     random_fill(empty_array)
-                    args.append(empty_array)
-            if task_inputs_count != len(task_input_names):
-                logger.warning(
-                    "task_inputs not fully matched, check if there's any 
unexpected error"
-                )
+                    args[idx] = empty_array
+                else:
+                    args[idx] = ndarray.array(args[idx], dev)
             dev.sync()
 
             # First run for check that the kernel is correct
@@ -1152,7 +1169,7 @@ def _rpc_run_worker(args):
     res : MeasureResult
         The measure result of this Runner thread.
     """
-    _, build_res, _, _, _, _, timeout, _, _, _, _, _, verbose = args
+    _, build_res, _, _, _, _, _, timeout, _, _, _, _, _, verbose = args
     if build_res.error_no != MeasureErrorNo.NO_ERROR:
         return (
             (MAX_FLOAT,),
@@ -1162,24 +1179,16 @@ def _rpc_run_worker(args):
             time.time(),
         )
 
-    res = call_func_with_timeout(timeout, _timed_rpc_run, args=args)
-    if isinstance(res, TimeoutError):
-        if verbose >= 1:
-            print("*T", end="")  # Run timeout
-        res = (
-            (MAX_FLOAT,),
-            MeasureErrorNo.RUN_TIMEOUT,
-            None,
-            build_res.time_cost + timeout,
-            time.time(),
-        )
-    elif isinstance(res, Exception):
+    try:
+        res = _rpc_run(*args)
+    # pylint: disable=broad-except
+    except Exception:
         if verbose >= 1:
             print("*E", end="")  # Run error
         res = (
             (MAX_FLOAT,),
             MeasureErrorNo.RUNTIME_DEVICE,
-            str(res),
+            make_traceback_info(),
             build_res.time_cost + timeout,
             time.time(),
         )
@@ -1259,13 +1268,14 @@ def rpc_runner_run(
     """
     assert len(inputs) == len(build_results), "Measure input size should be 
equal to build results"
     # This pool is not doing computationally intensive work, so we can use 
threads
-    pool = multiprocessing.pool.ThreadPool(n_parallel)
-    tuple_res = pool.map(
+    executor = PopenPoolExecutor(n_parallel)
+    tuple_res = executor.map_with_error_catching(
         _rpc_run_worker,
         [
             (
                 inp.serialize(),
                 build_res,
+                prepare_runner_args(inp, build_res),
                 key,
                 host,
                 port,
@@ -1281,13 +1291,25 @@ def rpc_runner_run(
             for inp, build_res in zip(inputs, build_results)
         ],
     )
-    pool.terminate()
-    pool.join()
-    del pool
 
     results = []
-    for res in tuple_res:
-        results.append(MeasureResult(*res))
+    for i, res in enumerate(tuple_res):
+        if res.status == StatusKind.COMPLETE:
+            results.append(MeasureResult(*res.value))
+        else:
+            assert res.status == StatusKind.TIMEOUT
+            if verbose >= 1:
+                print("*T", end="")  # Run timeout
+            build_res = build_results[i]
+            results.append(
+                MeasureResult(
+                    (MAX_FLOAT,),
+                    MeasureErrorNo.RUN_TIMEOUT,
+                    None,
+                    build_res.time_cost + timeout,
+                    time.time(),
+                )
+            )
 
     if verbose >= 1:
         print("")
diff --git a/python/tvm/auto_scheduler/utils.py 
b/python/tvm/auto_scheduler/utils.py
index 1c03491..9919bcb 100644
--- a/python/tvm/auto_scheduler/utils.py
+++ b/python/tvm/auto_scheduler/utils.py
@@ -20,9 +20,6 @@
 
 from typing import Hashable
 import json
-import multiprocessing
-import multiprocessing.pool
-import queue
 import signal
 import threading
 import traceback
@@ -289,41 +286,15 @@ def call_func_with_thread(func, args, kwargs):
     return res[0]
 
 
-def _func_wrapper(que, func, args, kwargs, add_thread_wrapper):
-    """Call function and return the result over the queue."""
-    try:
-        if add_thread_wrapper:
-            # Add a new layer of threadinng to avoid the conflict between
-            # python's multiprocessing and tvm's thread pool.
-            res = call_func_with_thread(func, args, kwargs)
-        else:
-            res = func(*args, **kwargs)
-        que.put(res)
-    except Exception:  # pylint: disable=broad-except
-        que.put(Exception(make_traceback_info()))
-
-
-def call_func_with_timeout(timeout, func, args=(), kwargs=None, 
add_thread_wrapper=False):
+def call_func_with_timeout(
+    worker, timeout, func, args=(), kwargs=None
+):  # pylint: disable=unused-argument
     """Call a function with timeout"""
-    que = multiprocessing.Queue(2)
-    process = multiprocessing.Process(
-        target=_func_wrapper, args=(que, func, args, kwargs or {}, 
add_thread_wrapper)
-    )
-    process.start()
-
+    worker.send(func, args, kwargs, timeout)
     try:
-        res = que.get(timeout=timeout)
-    except queue.Empty:
-        res = TimeoutError()
-
-    # clean queue and process
-    kill_child_processes(process.pid)
-    process.terminate()
-    process.join()
-    que.close()
-    que.join_thread()
-    del process
-    del que
+        res = worker.recv()
+    except Exception:  # pylint: disable=broad-except
+        res = Exception(make_traceback_info())
 
     return res
 
diff --git a/python/tvm/autotvm/record.py b/python/tvm/autotvm/record.py
index 4f11aea..8145563 100644
--- a/python/tvm/autotvm/record.py
+++ b/python/tvm/autotvm/record.py
@@ -21,7 +21,6 @@
 import argparse
 import base64
 import logging
-import multiprocessing
 import pickle
 import json
 import time
@@ -32,6 +31,7 @@ import numpy as np
 
 from .. import build, lower
 from ..target import Target
+from ..contrib import popen_pool
 from .. import __version__
 from . import task
 from .task import ConfigEntity, ApplyHistoryBest
@@ -230,7 +230,7 @@ def split_workload(in_file, clean=True):
     lines = list(open(in_file).readlines())
 
     logger.info("start converting...")
-    pool = multiprocessing.Pool()
+    pool = popen_pool.PopenPoolExecutor()
     lines = [rec for rec in pool.map(decode, lines) if rec is not None]
     logger.info("map done %.2f", time.time() - tic)
 
diff --git a/python/tvm/autotvm/utils.py b/python/tvm/autotvm/utils.py
index fa1dcfd..ec3f18d 100644
--- a/python/tvm/autotvm/utils.py
+++ b/python/tvm/autotvm/utils.py
@@ -17,7 +17,6 @@
 # pylint: disable=invalid-name
 """Utilities"""
 import logging
-import multiprocessing
 import time
 
 from random import randrange
@@ -25,6 +24,7 @@ from random import randrange
 import numpy as np
 import tvm.arith
 from tvm.tir import expr
+from tvm.contrib.popen_pool import PopenPoolExecutor
 
 logger = logging.getLogger("autotvm")
 
@@ -111,7 +111,7 @@ def pool_map(func, args, batch_size, verbose=False, 
pool=None):
 
     ret = None
     tic = time.time()
-    local_pool = pool or multiprocessing.Pool()
+    local_pool = pool or PopenPoolExecutor()
     if verbose:
         logger.info("mapping begin")
     for i in range(0, len(args), batch_size):
diff --git a/python/tvm/contrib/popen_pool.py b/python/tvm/contrib/popen_pool.py
index 2f55203..68c21ef 100644
--- a/python/tvm/contrib/popen_pool.py
+++ b/python/tvm/contrib/popen_pool.py
@@ -269,9 +269,16 @@ class PopenPoolExecutor:
 
     timeout : float
         Timeout value for each function submit.
+    Notes
+    -----
+    If max_workers is None, the number returned by
+    os.cpu_count() is used, matching the default
+    behavior of multiprocessing.Pool().
     """
 
-    def __init__(self, max_workers, timeout=None):
+    def __init__(self, max_workers=None, timeout=None):
+        if max_workers is None:
+            max_workers = os.cpu_count()
         # Use an internal thread pool to send to popen workers
         self._threadpool = 
concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
         self._timeout = timeout
diff --git a/python/tvm/testing/__init__.py b/python/tvm/testing/__init__.py
new file mode 100644
index 0000000..bd1ada4
--- /dev/null
+++ b/python/tvm/testing/__init__.py
@@ -0,0 +1,34 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin, wildcard-import
+"""Utility Python functions for TVM testing"""
+from .utils import assert_allclose, assert_prim_expr_equal, 
check_bool_expr_is_true
+from .utils import check_int_constraints_trans_consistency, 
check_numerical_grads
+from .utils import device_enabled, enabled_targets, exclude_targets
+from .utils import fixture, parameter, parameters, parametrize_targets, 
uses_gpu
+from .utils import known_failing_targets, requires_cuda, requires_cudagraph
+from .utils import requires_gpu, requires_llvm, requires_rocm, requires_rpc
+from .utils import requires_tensorcore, requires_metal, requires_micro, 
requires_opencl
+from .utils import _auto_parametrize_target, _count_num_fixture_uses
+from .utils import _remove_global_fixture_definitions, 
_parametrize_correlated_parameters
+from .utils import _pytest_target_params, identity_after, terminate_self
+
+from ._ffi_api import nop, echo, device_test, run_check_signal, 
object_use_count
+from ._ffi_api import test_wrap_callback, test_raise_error_callback, 
test_check_eq_callback
+from ._ffi_api import ErrorTest, FrontendTestModule
+
+from . import auto_scheduler
diff --git a/python/tvm/testing/_ffi_api.py b/python/tvm/testing/_ffi_api.py
new file mode 100644
index 0000000..56a7722
--- /dev/null
+++ b/python/tvm/testing/_ffi_api.py
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""FFI APIs for tvm.testing"""
+import tvm._ffi
+
+
+tvm._ffi._init_api("testing", __name__)
diff --git a/tests/python/unittest/test_auto_scheduler_common.py 
b/python/tvm/testing/auto_scheduler.py
similarity index 99%
rename from tests/python/unittest/test_auto_scheduler_common.py
rename to python/tvm/testing/auto_scheduler.py
index 4890268..bc335c8 100644
--- a/tests/python/unittest/test_auto_scheduler_common.py
+++ b/python/tvm/testing/auto_scheduler.py
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
+# pylint: disable=invalid-name, missing-function-docstring
 """Common functions for auto_scheduler test cases"""
 import tvm
 from tvm import auto_scheduler, te, topi
diff --git a/python/tvm/testing.py b/python/tvm/testing/utils.py
similarity index 99%
rename from python/tvm/testing.py
rename to python/tvm/testing/utils.py
index 9515189..71ab077 100644
--- a/python/tvm/testing.py
+++ b/python/tvm/testing/utils.py
@@ -1376,6 +1376,3 @@ def identity_after(x, sleep):
 def terminate_self():
     """Testing function to terminate the process."""
     sys.exit(-1)
-
-
-tvm._ffi._init_api("testing", __name__)
diff --git a/tests/python/unittest/test_auto_scheduler_compute_dag.py 
b/tests/python/unittest/test_auto_scheduler_compute_dag.py
index e394115..81ee5ca 100644
--- a/tests/python/unittest/test_auto_scheduler_compute_dag.py
+++ b/tests/python/unittest/test_auto_scheduler_compute_dag.py
@@ -23,7 +23,7 @@ import tvm
 from tvm import topi
 from tvm import auto_scheduler, te
 
-from test_auto_scheduler_common import (
+from tvm.testing.auto_scheduler import (
     get_tiled_matmul,
     invalid_compute_definition,
     matmul_auto_scheduler_test,
diff --git a/tests/python/unittest/test_auto_scheduler_cost_model.py 
b/tests/python/unittest/test_auto_scheduler_cost_model.py
index 0b34615..50e3ceb 100644
--- a/tests/python/unittest/test_auto_scheduler_cost_model.py
+++ b/tests/python/unittest/test_auto_scheduler_cost_model.py
@@ -24,7 +24,7 @@ import numpy as np
 import tvm
 from tvm import auto_scheduler
 
-from test_auto_scheduler_common import matmul_auto_scheduler_test
+from tvm.testing.auto_scheduler import matmul_auto_scheduler_test
 
 
 def get_sample_records(number):
diff --git a/tests/python/unittest/test_auto_scheduler_evolutionary_search.py 
b/tests/python/unittest/test_auto_scheduler_evolutionary_search.py
index e28219d..b5c99c0 100644
--- a/tests/python/unittest/test_auto_scheduler_evolutionary_search.py
+++ b/tests/python/unittest/test_auto_scheduler_evolutionary_search.py
@@ -18,7 +18,7 @@
 
 import tvm
 import pytest
-from test_auto_scheduler_common import matmul_auto_scheduler_test
+from tvm.testing.auto_scheduler import matmul_auto_scheduler_test
 from tvm import auto_scheduler, te
 from tvm.auto_scheduler.cost_model.cost_model import PythonBasedModel
 
diff --git a/tests/python/unittest/test_auto_scheduler_feature.py 
b/tests/python/unittest/test_auto_scheduler_feature.py
index 82cfb1d..96090e3 100644
--- a/tests/python/unittest/test_auto_scheduler_feature.py
+++ b/tests/python/unittest/test_auto_scheduler_feature.py
@@ -23,7 +23,7 @@ import tempfile
 import tvm
 from tvm import te, auto_scheduler
 
-from test_auto_scheduler_common import matmul_auto_scheduler_test
+from tvm.testing.auto_scheduler import matmul_auto_scheduler_test
 
 
 def fequal(a, b):
diff --git a/tests/python/unittest/test_auto_scheduler_layout_rewrite.py 
b/tests/python/unittest/test_auto_scheduler_layout_rewrite.py
index c929196..39673fa 100644
--- a/tests/python/unittest/test_auto_scheduler_layout_rewrite.py
+++ b/tests/python/unittest/test_auto_scheduler_layout_rewrite.py
@@ -26,7 +26,7 @@ import tvm.testing
 from tvm import topi
 from tvm import auto_scheduler, te
 
-from test_auto_scheduler_common import get_tiled_matmul, 
matmul_auto_scheduler_test
+from tvm.testing.auto_scheduler import get_tiled_matmul, 
matmul_auto_scheduler_test
 
 
 def test_apply_steps_with_layout_rewrite():
diff --git a/tests/python/unittest/test_auto_scheduler_loop_state.py 
b/tests/python/unittest/test_auto_scheduler_loop_state.py
index 44ed1fc..0965ed9 100644
--- a/tests/python/unittest/test_auto_scheduler_loop_state.py
+++ b/tests/python/unittest/test_auto_scheduler_loop_state.py
@@ -23,7 +23,7 @@ import tvm
 from tvm import auto_scheduler, te
 from tvm import topi
 
-from test_auto_scheduler_common import (
+from tvm.testing.auto_scheduler import (
     matmul_auto_scheduler_test,
     conv2d_nchw_bn_relu_auto_scheduler_test,
 )
diff --git a/tests/python/unittest/test_auto_scheduler_measure.py 
b/tests/python/unittest/test_auto_scheduler_measure.py
index 375f816..9eae3dd 100644
--- a/tests/python/unittest/test_auto_scheduler_measure.py
+++ b/tests/python/unittest/test_auto_scheduler_measure.py
@@ -26,7 +26,7 @@ from tvm import te, auto_scheduler
 import tempfile
 import tvm.testing
 import pickle
-from test_auto_scheduler_common import matmul_auto_scheduler_test
+from tvm.testing.auto_scheduler import matmul_auto_scheduler_test
 from tvm.auto_scheduler import workload_registry
 
 
diff --git a/tests/python/unittest/test_auto_scheduler_search_policy.py 
b/tests/python/unittest/test_auto_scheduler_search_policy.py
index d114ce4..a9f6596 100644
--- a/tests/python/unittest/test_auto_scheduler_search_policy.py
+++ b/tests/python/unittest/test_auto_scheduler_search_policy.py
@@ -27,7 +27,7 @@ import tvm.testing
 from tvm import auto_scheduler
 from tvm.auto_scheduler.utils import get_const_tuple
 
-from test_auto_scheduler_common import (
+from tvm.testing.auto_scheduler import (
     matmul_auto_scheduler_test,
     zero_rank_compute_auto_scheduler_test,
     zero_rank_reduce_auto_scheduler_test,
diff --git a/tests/python/unittest/test_auto_scheduler_search_task.py 
b/tests/python/unittest/test_auto_scheduler_search_task.py
index cd47f1e..f23b02c 100644
--- a/tests/python/unittest/test_auto_scheduler_search_task.py
+++ b/tests/python/unittest/test_auto_scheduler_search_task.py
@@ -24,7 +24,7 @@ import tvm
 import tvm.testing
 from tvm import auto_scheduler
 from tvm.auto_scheduler.utils import get_const_tuple
-from test_auto_scheduler_common import (
+from tvm.testing.auto_scheduler import (
     matmul_auto_scheduler_test,
     zero_rank_compute_auto_scheduler_test,
     zero_rank_reduce_auto_scheduler_test,
diff --git a/tests/python/unittest/test_auto_scheduler_sketch_generation.py 
b/tests/python/unittest/test_auto_scheduler_sketch_generation.py
index 4092ae0..6d2f870 100644
--- a/tests/python/unittest/test_auto_scheduler_sketch_generation.py
+++ b/tests/python/unittest/test_auto_scheduler_sketch_generation.py
@@ -27,7 +27,7 @@ from tvm import te, auto_scheduler
 from tvm.auto_scheduler import _ffi_api
 from tvm.auto_scheduler.loop_state import Stage
 
-from test_auto_scheduler_common import (
+from tvm.testing.auto_scheduler import (
     matmul_auto_scheduler_test,
     double_matmul_auto_scheduler_test,
     conv2d_nchw_bn_relu_auto_scheduler_test,
diff --git a/tests/python/unittest/test_auto_scheduler_task_scheduler.py 
b/tests/python/unittest/test_auto_scheduler_task_scheduler.py
index bbe29b1..a3f3569 100644
--- a/tests/python/unittest/test_auto_scheduler_task_scheduler.py
+++ b/tests/python/unittest/test_auto_scheduler_task_scheduler.py
@@ -25,7 +25,7 @@ import tvm
 import tvm.testing
 from tvm import auto_scheduler
 
-from test_auto_scheduler_common import matmul_auto_scheduler_test
+from tvm.testing.auto_scheduler import matmul_auto_scheduler_test
 
 
 @tvm.testing.requires_llvm

Reply via email to