areusch commented on a change in pull request #8113:
URL: https://github.com/apache/tvm/pull/8113#discussion_r638925859
##########
File path: python/tvm/autotvm/measure/measure_methods.py
##########
@@ -575,18 +596,22 @@ def run_through_rpc(
                 f_preproc=f_prepare,
             )
-            try:
-                random_fill = remote.get_function("tvm.contrib.random.random_fill")
-            except AttributeError:
-                raise AttributeError(
-                    "Please make sure USE_RANDOM is ON in the config.cmake " "on the remote devices"
-                )
-            args = [nd.empty(x[0], x[1], dev) for x in build_result.arg_info]
-            if "scatter" not in measure_input.task.name:
-                # the index tensor of scatter op cannot be randomly initialized
-                for arg in args:
-                    random_fill(arg)
-            dev.sync()
+            if ref_input:
+                args = [nd.array(x, device=dev) for x in ref_input]
+            else:
+                try:
+                    random_fill = remote.get_function("tvm.contrib.random.random_fill")
+                except AttributeError:
+                    raise AttributeError(
+                        "Please make sure USE_RANDOM is ON in the config.cmake "
+                        "on the remote devices"
+                    )
+                args = [nd.empty(x[0], x[1], dev) for x in build_result.arg_info]
+                if "scatter" not in measure_input.task.name:
+                    # the index tensor of scatter op cannot be randomly initialized
+                    for arg in args:
+                        random_fill(arg)
+                dev.sync()
Review comment:
Actually, I was referring to the outer `if ref_input:`. But I think I'm
retracting my ask here, because I believe the point of `sync` is to ensure
that the output of `random_fill` has been materialized in memory.
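
To make the reasoning above concrete, here is a minimal sketch of the control flow in the `+` side of the diff. It does not use TVM: `FakeDevice`, `enqueue_fill`, and `prepare_args` are hypothetical stand-ins invented for illustration, and the sketch assumes that the reference-input copy needs no extra sync, which the diff itself does not confirm.

```python
# Hypothetical stand-ins for illustration only; none of these are TVM APIs.
class FakeDevice:
    """Tracks queued asynchronous fills and how many times sync() ran."""

    def __init__(self):
        self.pending = 0
        self.sync_calls = 0

    def enqueue_fill(self):
        # Stands in for random_fill(arg): work is queued, not yet visible.
        self.pending += 1

    def sync(self):
        # Stands in for dev.sync(): block until queued work has completed.
        self.pending = 0
        self.sync_calls += 1


def prepare_args(dev, arg_sizes, task_name, ref_input=None):
    """Mirror the diff: use reference inputs if given, else fill and sync."""
    if ref_input:
        # Assumption: copying caller-supplied data queues no random fill.
        return [list(x) for x in ref_input]
    args = [[0] * n for n in arg_sizes]
    if "scatter" not in task_name:
        for _ in args:
            dev.enqueue_fill()
    # The sync sits inside the else branch, next to the fill it waits for.
    dev.sync()
    return args
```

With this layout the sync is issued only on the path that queued random fills, which matches the reading that `sync` exists to make the `random_fill` results visible before measurement.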