This is an automated email from the ASF dual-hosted git repository.

jcf94 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 9daf3fe  Fix AutoScheduler for anaconda python (#7387)
9daf3fe is described below

commit 9daf3fee71db91e1adae8410ab8c70846df764ed
Author: dlexplorer <[email protected]>
AuthorDate: Sun Feb 7 05:47:30 2021 +0300

    Fix AutoScheduler for anaconda python (#7387)
    
    In case of a non-CPython flavour of Python, the task passed to the measure
    process should be serialized using the pickle approach. The task includes a
    workload, which is a list of Tensors. The list should be serialized and
    deserialized as an atomic object.
---
 python/tvm/auto_scheduler/workload_registry.py     | 18 ++++++----
 .../python/unittest/test_auto_scheduler_measure.py | 39 ++++++++++++++++++++++
 2 files changed, 51 insertions(+), 6 deletions(-)

diff --git a/python/tvm/auto_scheduler/workload_registry.py 
b/python/tvm/auto_scheduler/workload_registry.py
index 51ae64d..cd8f8c9 100644
--- a/python/tvm/auto_scheduler/workload_registry.py
+++ b/python/tvm/auto_scheduler/workload_registry.py
@@ -35,6 +35,7 @@ import pickle
 import json
 
 import tvm._ffi
+from tvm.runtime._ffi_node_api import LoadJSON, SaveJSON
 from .utils import serialize_args, deserialize_args, get_func_name
 
 logger = logging.getLogger("auto_scheduler")
@@ -216,13 +217,17 @@ def serialize_workload_registry_entry(workload_key):
     global WORKLOAD_FUNC_REGISTRY
 
     if workload_key in WORKLOAD_FUNC_REGISTRY:
-        return (workload_key, WORKLOAD_FUNC_REGISTRY[workload_key])
+        sname = workload_key
+    else:
+        workload = json.loads(workload_key)
+        sname = workload[0]
 
-    workload = json.loads(workload_key)
-    name = workload[0]
-    value = WORKLOAD_FUNC_REGISTRY[name]
+    svalue = WORKLOAD_FUNC_REGISTRY[sname]
+    if not callable(svalue):
+        # pylint: disable=assignment-from-no-return
+        svalue = SaveJSON(svalue)
 
-    return name, value
+    return sname, svalue
 
 
 def deserialize_workload_registry_entry(data):
@@ -239,7 +244,8 @@ def deserialize_workload_registry_entry(data):
 
     name, value = data
     if name not in WORKLOAD_FUNC_REGISTRY:
-        WORKLOAD_FUNC_REGISTRY[name] = value
+        # pylint: disable=assignment-from-no-return
+        WORKLOAD_FUNC_REGISTRY[name] = LoadJSON(value)
 
 
 def save_workload_func_registry(filename):
diff --git a/tests/python/unittest/test_auto_scheduler_measure.py 
b/tests/python/unittest/test_auto_scheduler_measure.py
index 041fb7e..cc9d7a4 100644
--- a/tests/python/unittest/test_auto_scheduler_measure.py
+++ b/tests/python/unittest/test_auto_scheduler_measure.py
@@ -24,8 +24,10 @@ from tvm import topi
 from tvm import te, auto_scheduler
 import tempfile
 import tvm.testing
+import pickle
 
 from test_auto_scheduler_common import matmul_auto_scheduler_test, 
get_tiled_matmul
+from tvm.auto_scheduler import workload_registry
 
 
 def record_common(dag, s):
@@ -255,6 +257,42 @@ def test_measure_local_builder_runner():
         assert mress[0].error_no == 0
 
 
+def test_dag_measure_local_builder_runner():
+    if not tvm.testing.device_enabled("llvm"):
+        return
+
+    A = te.placeholder((512, 512), name="A")
+    B = te.placeholder((512, 512), name="B")
+    k = te.reduce_axis((0, 512), name="k")
+    C = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * B[k][j], 
axis=[k]), name="C")
+    D = topi.nn.relu(C)
+    E = topi.nn.relu(D)
+
+    tensors = [A, B, E]
+    dag = auto_scheduler.ComputeDAG(tensors)
+    key = workload_registry.register_workload_tensors(dag.workload_key(), 
tensors)
+    transfer_data = workload_registry.serialize_workload_registry_entry(key)
+    f_data = pickle.dumps(transfer_data)
+    f_new = pickle.loads(f_data)
+    del workload_registry.WORKLOAD_FUNC_REGISTRY[key]
+    workload_registry.deserialize_workload_registry_entry(f_new)
+
+    target = tvm.target.Target("llvm")
+    task = auto_scheduler.SearchTask(compute_dag=dag, workload_key=key, 
target=target)
+
+    for enable_cpu_cache_flush in [True, False]:
+        minp = auto_scheduler.MeasureInput(task, task.compute_dag.init_state)
+        local_builder = auto_scheduler.LocalBuilder()
+        local_runner = auto_scheduler.LocalRunner(
+            timeout=60, enable_cpu_cache_flush=enable_cpu_cache_flush
+        )
+
+        bress = local_builder.build([minp])
+        assert bress[0].error_no == 0
+        mress = local_runner.run([minp], bress)
+        assert mress[0].error_no == 0
+
+
 def test_measure_local_builder_rpc_runner():
     if not tvm.testing.device_enabled("llvm"):
         return
@@ -325,5 +363,6 @@ if __name__ == "__main__":
     test_recover_measure_input()
     test_workload_dis_factor()
     test_measure_local_builder_runner()
+    test_dag_measure_local_builder_runner()
     test_measure_local_builder_rpc_runner()
     test_measure_target_host()

Reply via email to