This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ac4815c  [AutoScheduler] Allow device specification for AutoScheduler Runners. (#10123)
ac4815c is described below

commit ac4815c0d017284bc474fdad12fc20186a4357b8
Author: Josh Fromm <[email protected]>
AuthorDate: Tue Feb 1 15:09:27 2022 -0800

    [AutoScheduler] Allow device specification for AutoScheduler Runners. (#10123)
    
    * Changed the python api to support device.
    
    * Finished implementation and updated tests.
    
    * Fix typo.
---
 include/tvm/auto_scheduler/measure.h             |  8 +++++--
 python/tvm/auto_scheduler/measure.py             | 28 +++++++++++++++++++++---
 src/auto_scheduler/measure.cc                    | 18 ++++++++-------
 tests/python/relay/test_auto_scheduler_tuning.py |  2 +-
 4 files changed, 42 insertions(+), 14 deletions(-)

diff --git a/include/tvm/auto_scheduler/measure.h b/include/tvm/auto_scheduler/measure.h
index 20a93e2..8576468 100755
--- a/include/tvm/auto_scheduler/measure.h
+++ b/include/tvm/auto_scheduler/measure.h
@@ -308,6 +308,8 @@ class ProgramRunnerNode : public Object {
   double cooldown_interval;
   /*! \brief Whether to flush cache on CPU between repeated measurements. */
   bool enable_cpu_cache_flush;
+  /*! \brief Which device to run on if multiple are available. */
+  int device;
 
   /*!
    * \brief Run measurement and return results.
@@ -391,9 +393,10 @@ class LocalRunner : public ProgramRunner {
    * \param min_repeat_ms The minimum duration of one repeat in milliseconds.
    * \param cooldown_interval The cool down interval between two measurements.
    * \param enable_cpu_cache_flush Whether to flush cache on CPU between 
repeated measurements.
+   * \param device Which device to run on if multiple are available.
    */
   LocalRunner(int timeout, int number, int repeat, int min_repeat_ms, double 
cooldown_interval,
-              bool enable_cpu_cache_flush);
+              bool enable_cpu_cache_flush, int device);
 
   TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(LocalRunner, ProgramRunner, 
LocalRunnerNode);
 };
@@ -443,10 +446,11 @@ class RPCRunner : public ProgramRunner {
    * \param min_repeat_ms The minimum duration of one repeat in milliseconds.
    * \param cooldown_interval The cool down interval between two measurements.
    * \param enable_cpu_cache_flush Whether to flush cache on CPU between 
repeated measurements.
+   * \param device Which device to run on if multiple are available.
    */
   RPCRunner(const String& key, const String& host, int port, int priority, int 
n_parallel,
             int timeout, int number, int repeat, int min_repeat_ms, double 
cooldown_interval,
-            bool enable_cpu_cache_flush);
+            bool enable_cpu_cache_flush, int device);
 
   TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(RPCRunner, ProgramRunner, 
RPCRunnerNode);
 };
diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py
index dd23287..4148cdb 100644
--- a/python/tvm/auto_scheduler/measure.py
+++ b/python/tvm/auto_scheduler/measure.py
@@ -382,6 +382,8 @@ class LocalRunner(ProgramRunner):
         its actual latency during end-to-end inference.
         To make this option effective, the argument `number` should also be 
set to 1.
         This is only has effect on CPU task.
+    device: int = 0
+        Which device to run on if multiple are available.
     """
 
     def __init__(
@@ -392,6 +394,7 @@ class LocalRunner(ProgramRunner):
         min_repeat_ms=100,
         cooldown_interval=0.0,
         enable_cpu_cache_flush=False,
+        device=0,
     ):
         if enable_cpu_cache_flush:
             number = 1
@@ -405,6 +408,7 @@ class LocalRunner(ProgramRunner):
             min_repeat_ms,
             cooldown_interval,
             enable_cpu_cache_flush,
+            device,
         )
 
 
@@ -453,6 +457,8 @@ class RPCRunner(ProgramRunner):
         its actual latency during end-to-end inference.
         To make this option effective, the argument `number` should also be 
set to 1.
         This is only has effect on CPU task.
+    device: int = 0
+        Which device to run on if multiple are available.
     """
 
     def __init__(
@@ -468,6 +474,7 @@ class RPCRunner(ProgramRunner):
         min_repeat_ms=100,
         cooldown_interval=0.0,
         enable_cpu_cache_flush=False,
+        device=0,
     ):
         self.__init_handle_by_constructor__(
             _ffi_api.RPCRunner,
@@ -482,6 +489,7 @@ class RPCRunner(ProgramRunner):
             min_repeat_ms,
             cooldown_interval,
             enable_cpu_cache_flush,
+            device,
         )
 
         if check_remote(key, host, port, priority, timeout):
@@ -532,6 +540,8 @@ class LocalRPCMeasureContext:
         its actual latency during end-to-end inference.
         To make this option effective, the argument `number` should also be 
set to 1.
         This is only has effect on CPU task.
+    device: int = 0
+        Which device to run on if multiple are available.
     """
 
     def __init__(
@@ -544,6 +554,7 @@ class LocalRPCMeasureContext:
         min_repeat_ms=0,
         cooldown_interval=0.0,
         enable_cpu_cache_flush=False,
+        device=0,
     ):
         # pylint: disable=import-outside-toplevel
         from tvm.rpc.tracker import Tracker
@@ -570,6 +581,7 @@ class LocalRPCMeasureContext:
             min_repeat_ms,
             cooldown_interval,
             enable_cpu_cache_flush,
+            device,
         )
         # Wait for the processes to start
         time.sleep(0.5)
@@ -871,6 +883,7 @@ def _timed_eval_func(
     cooldown_interval,
     enable_cpu_cache_flush,
     verbose,
+    device,
 ):
     inp = MeasureInput.deserialize(inp_serialized)
     tic = time.time()
@@ -878,7 +891,7 @@ def _timed_eval_func(
     error_msg = None
     try:
         func = module.load_module(build_res.filename)
-        dev = ndarray.device(str(inp.task.target), 0)
+        dev = ndarray.device(str(inp.task.target), device)
         # Limitation:
         # We can not get PackFunction directly in the remote mode as it is 
wrapped
         # under the std::function. We could lift the restriction later once we 
fold
@@ -947,6 +960,7 @@ def local_run(
     cooldown_interval=0,
     enable_cpu_cache_flush=False,
     verbose=1,
+    device=0,
 ):
     """
     Run function of LocalRunner to test the performance of the input 
BuildResults.
@@ -986,6 +1000,8 @@ def local_run(
         This is only has effect on CPU task.
     verbose: int = 1
         Verbosity level. 0 for silent, 1 to output information during program 
measuring.
+    device: int = 0
+        Which device to run on if multiple are available.
 
     Returns
     -------
@@ -1021,6 +1037,7 @@ def local_run(
                     cooldown_interval,
                     enable_cpu_cache_flush,
                     verbose,
+                    device,
                 ),
             )
             if isinstance(res, TimeoutError):
@@ -1067,6 +1084,7 @@ def _rpc_run(
     cooldown_interval,
     enable_cpu_cache_flush,
     verbose,
+    device,
 ):
     inp = MeasureInput.deserialize(inp_serialized)
     tic = time.time()
@@ -1077,7 +1095,7 @@ def _rpc_run(
         remote = request_remote(key, host, port, priority, timeout)
         remote.upload(build_res.filename)
         func = remote.load_module(os.path.split(build_res.filename)[1])
-        dev = remote.device(str(inp.task.target), 0)
+        dev = remote.device(str(inp.task.target), device)
         # Limitation:
         # We can not get PackFunction directly in the remote mode as it is 
wrapped
         # under the std::function. We could lift the restriction later once we 
fold
@@ -1166,7 +1184,7 @@ def _rpc_run_worker(args):
     res : MeasureResult
         The measure result of this Runner thread.
     """
-    _, build_res, _, _, _, _, _, timeout, _, _, _, _, _, verbose = args
+    _, build_res, _, _, _, _, _, timeout, _, _, _, _, _, verbose, _ = args
     if build_res.error_no != MeasureErrorNo.NO_ERROR:
         return (
             (MAX_FLOAT,),
@@ -1209,6 +1227,7 @@ def rpc_runner_run(
     cooldown_interval=0.0,
     enable_cpu_cache_flush=False,
     verbose=1,
+    device=0,
 ):
     """Run function of RPCRunner to test the performance of the input 
BuildResults.
 
@@ -1257,6 +1276,8 @@ def rpc_runner_run(
         This is only has effect on CPU task.
     verbose: int = 1
         Verbosity level. 0 for silent, 1 to output information during program 
measuring.
+    device: int = 0
+        Which device to run on if multiple are available.
 
     Returns
     -------
@@ -1284,6 +1305,7 @@ def rpc_runner_run(
                 cooldown_interval,
                 enable_cpu_cache_flush,
                 verbose,
+                device,
             )
             for inp, build_res in zip(inputs, build_results)
         ],
diff --git a/src/auto_scheduler/measure.cc b/src/auto_scheduler/measure.cc
index c3212f2..abb7758 100755
--- a/src/auto_scheduler/measure.cc
+++ b/src/auto_scheduler/measure.cc
@@ -127,7 +127,7 @@ Array<BuildResult> LocalBuilderNode::Build(const 
Array<MeasureInput>& inputs, in
 
 /********** LocalRunner **********/
 LocalRunner::LocalRunner(int timeout, int number, int repeat, int 
min_repeat_ms,
-                         double cooldown_interval, bool 
enable_cpu_cache_flush) {
+                         double cooldown_interval, bool 
enable_cpu_cache_flush, int device) {
   ObjectPtr<LocalRunnerNode> node = make_object<LocalRunnerNode>();
   node->timeout = timeout;
   node->number = number;
@@ -135,6 +135,7 @@ LocalRunner::LocalRunner(int timeout, int number, int 
repeat, int min_repeat_ms,
   node->min_repeat_ms = min_repeat_ms;
   node->cooldown_interval = cooldown_interval;
   node->enable_cpu_cache_flush = enable_cpu_cache_flush;
+  node->device = device;
   data_ = std::move(node);
 }
 
@@ -143,7 +144,7 @@ Array<MeasureResult> LocalRunnerNode::Run(const 
Array<MeasureInput>& inputs,
   if (const auto* f = 
runtime::Registry::Get("auto_scheduler.local_runner.run")) {
     Array<MeasureResult> results =
         (*f)(inputs, build_results, timeout, number, repeat, min_repeat_ms, 
cooldown_interval,
-             enable_cpu_cache_flush, verbose);
+             enable_cpu_cache_flush, verbose, device);
     return results;
   }
   LOG(FATAL) << "auto_scheduler.local_runner.run is not registered. "
@@ -155,7 +156,7 @@ Array<MeasureResult> LocalRunnerNode::Run(const 
Array<MeasureInput>& inputs,
 /********** RPCRunner **********/
 RPCRunner::RPCRunner(const String& key, const String& host, int port, int 
priority, int n_parallel,
                      int timeout, int number, int repeat, int min_repeat_ms,
-                     double cooldown_interval, bool enable_cpu_cache_flush) {
+                     double cooldown_interval, bool enable_cpu_cache_flush, 
int device) {
   auto node = make_object<RPCRunnerNode>();
   node->key = key;
   node->host = host;
@@ -168,6 +169,7 @@ RPCRunner::RPCRunner(const String& key, const String& host, 
int port, int priori
   node->min_repeat_ms = min_repeat_ms;
   node->cooldown_interval = cooldown_interval;
   node->enable_cpu_cache_flush = enable_cpu_cache_flush;
+  node->device = device;
   data_ = std::move(node);
 }
 
@@ -176,7 +178,7 @@ Array<MeasureResult> RPCRunnerNode::Run(const 
Array<MeasureInput>& inputs,
   if (const auto* f = runtime::Registry::Get("auto_scheduler.rpc_runner.run")) 
{
     Array<MeasureResult> results =
         (*f)(inputs, build_results, key, host, port, priority, n_parallel, 
timeout, number, repeat,
-             min_repeat_ms, cooldown_interval, enable_cpu_cache_flush, 
verbose);
+             min_repeat_ms, cooldown_interval, enable_cpu_cache_flush, 
verbose, device);
     return results;
   } else {
     LOG(FATAL) << "auto_scheduler.rpc_runner.run is not registered. "
@@ -409,17 +411,17 @@ TVM_REGISTER_GLOBAL("auto_scheduler.LocalBuilder")
 
 TVM_REGISTER_GLOBAL("auto_scheduler.LocalRunner")
     .set_body_typed([](int timeout, int number, int repeat, int min_repeat_ms,
-                       double cooldown_interval, bool enable_cpu_cache_flush) {
+                       double cooldown_interval, bool enable_cpu_cache_flush, 
int device) {
       return LocalRunner(timeout, number, repeat, min_repeat_ms, 
cooldown_interval,
-                         enable_cpu_cache_flush);
+                         enable_cpu_cache_flush, device);
     });
 
 TVM_REGISTER_GLOBAL("auto_scheduler.RPCRunner")
     .set_body_typed([](const String& key, const String& host, int port, int 
priority,
                        int n_parallel, int timeout, int number, int repeat, 
int min_repeat_ms,
-                       double cooldown_interval, bool enable_cpu_cache_flush) {
+                       double cooldown_interval, bool enable_cpu_cache_flush, 
int device) {
       return RPCRunner(key, host, port, priority, n_parallel, timeout, number, 
repeat,
-                       min_repeat_ms, cooldown_interval, 
enable_cpu_cache_flush);
+                       min_repeat_ms, cooldown_interval, 
enable_cpu_cache_flush, device);
     });
 
 }  // namespace auto_scheduler
diff --git a/tests/python/relay/test_auto_scheduler_tuning.py b/tests/python/relay/test_auto_scheduler_tuning.py
index bbf3c48..1431824 100644
--- a/tests/python/relay/test_auto_scheduler_tuning.py
+++ b/tests/python/relay/test_auto_scheduler_tuning.py
@@ -36,7 +36,7 @@ def tune_network(network, target):
         log_file = fp.name
 
         # Tuning
-        measure_ctx = auto_scheduler.LocalRPCMeasureContext(timeout=60)
+        measure_ctx = auto_scheduler.LocalRPCMeasureContext(timeout=60, 
device=0)
         tuner = auto_scheduler.TaskScheduler(tasks, task_weights, callbacks=[])
         tune_option = auto_scheduler.TuningOptions(
             num_measure_trials=100,

Reply via email to