This is an automated email from the ASF dual-hosted git repository.
jcf94 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 9daf3fe Fix AutoScheduler for anaconda python (#7387)
9daf3fe is described below
commit 9daf3fee71db91e1adae8410ab8c70846df764ed
Author: dlexplorer <[email protected]>
AuthorDate: Sun Feb 7 05:47:30 2021 +0300
Fix AutoScheduler for anaconda python (#7387)
With a non-CPython flavour of Python, the task passed to the measure process
must be serialized with pickle. The task includes its workload, which is a
list of Tensors; that list must be serialized and deserialized as a single
atomic object.
---
python/tvm/auto_scheduler/workload_registry.py | 18 ++++++----
.../python/unittest/test_auto_scheduler_measure.py | 39 ++++++++++++++++++++++
2 files changed, 51 insertions(+), 6 deletions(-)
diff --git a/python/tvm/auto_scheduler/workload_registry.py b/python/tvm/auto_scheduler/workload_registry.py
index 51ae64d..cd8f8c9 100644
--- a/python/tvm/auto_scheduler/workload_registry.py
+++ b/python/tvm/auto_scheduler/workload_registry.py
@@ -35,6 +35,7 @@ import pickle
import json
import tvm._ffi
+from tvm.runtime._ffi_node_api import LoadJSON, SaveJSON
from .utils import serialize_args, deserialize_args, get_func_name
logger = logging.getLogger("auto_scheduler")
@@ -216,13 +217,17 @@ def serialize_workload_registry_entry(workload_key):
global WORKLOAD_FUNC_REGISTRY
if workload_key in WORKLOAD_FUNC_REGISTRY:
- return (workload_key, WORKLOAD_FUNC_REGISTRY[workload_key])
+ sname = workload_key
+ else:
+ workload = json.loads(workload_key)
+ sname = workload[0]
- workload = json.loads(workload_key)
- name = workload[0]
- value = WORKLOAD_FUNC_REGISTRY[name]
+ svalue = WORKLOAD_FUNC_REGISTRY[sname]
+ if not callable(svalue):
+ # pylint: disable=assignment-from-no-return
+ svalue = SaveJSON(svalue)
- return name, value
+ return sname, svalue
def deserialize_workload_registry_entry(data):
@@ -239,7 +244,8 @@ def deserialize_workload_registry_entry(data):
name, value = data
if name not in WORKLOAD_FUNC_REGISTRY:
- WORKLOAD_FUNC_REGISTRY[name] = value
+ # pylint: disable=assignment-from-no-return
+ WORKLOAD_FUNC_REGISTRY[name] = LoadJSON(value)
def save_workload_func_registry(filename):
diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py
index 041fb7e..cc9d7a4 100644
--- a/tests/python/unittest/test_auto_scheduler_measure.py
+++ b/tests/python/unittest/test_auto_scheduler_measure.py
@@ -24,8 +24,10 @@ from tvm import topi
from tvm import te, auto_scheduler
import tempfile
import tvm.testing
+import pickle
from test_auto_scheduler_common import matmul_auto_scheduler_test,
get_tiled_matmul
+from tvm.auto_scheduler import workload_registry
def record_common(dag, s):
@@ -255,6 +257,42 @@ def test_measure_local_builder_runner():
assert mress[0].error_no == 0
+def test_dag_measure_local_builder_runner():
+ if not tvm.testing.device_enabled("llvm"):
+ return
+
+ A = te.placeholder((512, 512), name="A")
+ B = te.placeholder((512, 512), name="B")
+ k = te.reduce_axis((0, 512), name="k")
+ C = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * B[k][j],
axis=[k]), name="C")
+ D = topi.nn.relu(C)
+ E = topi.nn.relu(D)
+
+ tensors = [A, B, E]
+ dag = auto_scheduler.ComputeDAG(tensors)
+ key = workload_registry.register_workload_tensors(dag.workload_key(),
tensors)
+ transfer_data = workload_registry.serialize_workload_registry_entry(key)
+ f_data = pickle.dumps(transfer_data)
+ f_new = pickle.loads(f_data)
+ del workload_registry.WORKLOAD_FUNC_REGISTRY[key]
+ workload_registry.deserialize_workload_registry_entry(f_new)
+
+ target = tvm.target.Target("llvm")
+ task = auto_scheduler.SearchTask(compute_dag=dag, workload_key=key,
target=target)
+
+ for enable_cpu_cache_flush in [True, False]:
+ minp = auto_scheduler.MeasureInput(task, task.compute_dag.init_state)
+ local_builder = auto_scheduler.LocalBuilder()
+ local_runner = auto_scheduler.LocalRunner(
+ timeout=60, enable_cpu_cache_flush=enable_cpu_cache_flush
+ )
+
+ bress = local_builder.build([minp])
+ assert bress[0].error_no == 0
+ mress = local_runner.run([minp], bress)
+ assert mress[0].error_no == 0
+
+
def test_measure_local_builder_rpc_runner():
if not tvm.testing.device_enabled("llvm"):
return
@@ -325,5 +363,6 @@ if __name__ == "__main__":
test_recover_measure_input()
test_workload_dis_factor()
test_measure_local_builder_runner()
+ test_dag_measure_local_builder_runner()
test_measure_local_builder_rpc_runner()
test_measure_target_host()