merrymercy commented on a change in pull request #7053:
URL: https://github.com/apache/tvm/pull/7053#discussion_r538809651
##########
File path: python/tvm/auto_scheduler/auto_schedule.py
##########
@@ -210,6 +253,35 @@ def auto_schedule(task, search_policy=None,
tuning_options=TuningOptions()):
if search_policy is None:
cost_model = XGBModel()
search_policy = SketchPolicy(task, cost_model)
+
+ if tuning_options.check_correctness == True:
+ empty_sch, args = task.compute_dag.apply_steps_from_state(
+ task.compute_dag.get_init_state(), layout_rewrite=True)
Review comment:
```suggestion
task.compute_dag.get_init_state())
```
##########
File path: python/tvm/auto_scheduler/auto_schedule.py
##########
@@ -210,6 +253,35 @@ def auto_schedule(task, search_policy=None,
tuning_options=TuningOptions()):
if search_policy is None:
cost_model = XGBModel()
search_policy = SketchPolicy(task, cost_model)
+
+ if tuning_options.check_correctness == True:
+ empty_sch, args = task.compute_dag.apply_steps_from_state(
+ task.compute_dag.get_init_state(), layout_rewrite=True)
+ cpu_func = build_module.build(
+ empty_sch, args, target="llvm",
target_host=task.target_host
Review comment:
```suggestion
empty_sch, args, target="llvm",
```
This function runs on the host machine, so we should not use `target_host`.
##########
File path: python/tvm/auto_scheduler/measure.py
##########
@@ -889,18 +929,43 @@ def _timed_rpc_run(
if error_no == 0:
try:
- args = [ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for
x in build_res.args]
- try:
- random_fill =
remote.get_function("tvm.contrib.random.random_fill")
- except AttributeError:
- raise AttributeError(
- "Please make sure USE_RANDOM is ON in the config.cmake "
"on the remote devices"
- )
- for arg in args:
- random_fill(arg)
+ if os.path.exists(working_dir):
+ buffer_path = os.path.join(working_dir, "buffer.pkl")
+ if os.path.exists(buffer_path):
+ with open(buffer_path, "rb") as fi:
+ buffer = pickle.load(fi)
+ # force last args to be empty
+ args = []
+ for i in range(len(build_res.args) - 1):
+
args.append(ndarray.array(buffer[build_res.args[i].name], ctx=ctx))
Review comment:
This will cause a copy over RPC for every measurement. A better approach is
to copy `buffer.pkl` to the remote device once and perform the
initialization/check on the remote device.
We can leave a "TODO" here and defer this optimization to future PRs.
##########
File path: python/tvm/auto_scheduler/measure.py
##########
@@ -275,13 +278,16 @@ class LocalBuilder(ProgramBuilder):
timeout : int = 15
The timeout limit (in second) for each build thread.
This is used in a wrapper of the multiprocessing.Process.join().
- n_parallel : int = multiprocessing.cpu_count()
+ n_parallel : int = -1
Review comment:
```suggestion
n_parallel : Optional[int] = None
```
Use `None` as the default value.
##########
File path: python/tvm/auto_scheduler/auto_schedule.py
##########
@@ -210,6 +253,35 @@ def auto_schedule(task, search_policy=None,
tuning_options=TuningOptions()):
if search_policy is None:
cost_model = XGBModel()
search_policy = SketchPolicy(task, cost_model)
+
+ if tuning_options.check_correctness == True:
+ empty_sch, args = task.compute_dag.apply_steps_from_state(
+ task.compute_dag.get_init_state(), layout_rewrite=True)
+ cpu_func = build_module.build(
+ empty_sch, args, target="llvm",
target_host=task.target_host
+ )
+ buffer_path = os.path.join(tuning_options.working_dir, "buffer.pkl")
+ if os.path.exists(buffer_path) is True:
+ with open(buffer_path, "rb") as fi:
+ buffer = pickle.load(fi)
+ if len(buffer) == len(args):
+ # we skip check each arg shape here
+ pass
+ elif len(buffer) == len(args) - 1:
+ # assume only one output
+ np_args =
np.zeros(size=get_const_tuple(args[-1].shape)).astype(args[-1].dtype)
+ cpu_args = [v for _, v in buffer.items()] +
[ndarray.array(np_args, ctx=tvm.cpu())]
Review comment:
```suggestion
cpu_args = [v for _, v in buffer.items()] +
[ndarray.empty(shape=get_const_tuple(args[-1].shape), dtype=args[-1].dtype,
ctx=tvm.cpu())]
```
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]