This is an automated email from the ASF dual-hosted git repository.
comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new fd39122 [AutoScheduler] Enable schedule sharing in dispatch context
(#7344)
fd39122 is described below
commit fd391223c19bec454f488f8a976a0766fadb0db3
Author: Cody Yu <[email protected]>
AuthorDate: Wed Jan 27 14:54:43 2021 -0800
[AutoScheduler] Enable schedule sharing in dispatch context (#7344)
* [AutoScheduler] Enable schedule sharing in dispatch context
* Update python/tvm/auto_scheduler/dispatcher.py
---
python/tvm/auto_scheduler/dispatcher.py | 135 ++++++++++++++++-----
python/tvm/auto_scheduler/measure_record.py | 65 +---------
python/tvm/auto_scheduler/utils.py | 65 +++++++++-
.../python/unittest/test_auto_scheduler_measure.py | 18 +--
4 files changed, 178 insertions(+), 105 deletions(-)
diff --git a/python/tvm/auto_scheduler/dispatcher.py
b/python/tvm/auto_scheduler/dispatcher.py
index b0b98d8..f2d7536 100644
--- a/python/tvm/auto_scheduler/dispatcher.py
+++ b/python/tvm/auto_scheduler/dispatcher.py
@@ -30,6 +30,7 @@ import numpy as np
from tvm.tir.expr import FloatImm
from .measure_record import load_records
+from .utils import calc_workload_dis_factor, decode_workload_key
logger = logging.getLogger("auto_scheduler")
@@ -126,18 +127,53 @@ class ApplyHistoryBest(DispatchContext):
If is str, then it should be the filename of a records log file.
Each row of this file is an encoded record pair. Otherwise, it is an
iterator.
n_lines: Optional[int]
- if it is not None, only load the first `n_lines` lines of log
+ if it is not None, only load the first `n_lines` lines of log.
+ include_compatible: bool
+ When set to True, compatible records will also be considered.
"""
- def __init__(self, records, n_lines=None):
+ def __init__(self, records, n_lines=None, include_compatible=False):
super(ApplyHistoryBest, self).__init__()
+ self.include_compatible = include_compatible
+ # Dict[str (target key),
+ # Dict[str (workload hash),
+ # Dict[tuple (workload args), tuple (State, cost)]]]
self.best_by_targetkey = {}
self.best_by_model = {}
self._best_user_defined = {}
self.load(records, n_lines)
+ @staticmethod
+ def get_workload_entry(best_records, target_key, workload_key):
+ """Get the entry of the target key and workload key hash in the given
best record map.
+
+ Parameters
+ ----------
+ best_records: Dict[str, Dict[str, Dict[str, Any]]]
+ The best record map.
+ target_key: str
+ The first key to the best_records.
+ workload_key: str
+ The workload key that can be decoded to workload hash and args.
+
+ Returns
+ -------
+ entry: Dict[str, Any]
+ The entry in best_records with target key and workload hash.
+ workload_hash: str
+ The workload hash decoded from workload_key.
+ workload_args: Tuple[Any, ...]
+ The hashable tuple of workload args decoded from workload_key.
+ """
+ workload_hash, workload_args = decode_workload_key(workload_key)
+ if target_key not in best_records:
+ best_records[target_key] = {}
+ if workload_hash not in best_records[target_key]:
+ best_records[target_key][workload_hash] = {}
+ return best_records[target_key][workload_hash], workload_hash,
workload_args
+
def load(self, records, n_lines=None):
"""Load records to this dispatch context
@@ -171,29 +207,32 @@ class ApplyHistoryBest(DispatchContext):
if res.error_no != 0:
continue
+ costs = [x.value for x in res.costs if isinstance(x, FloatImm)]
+ cost = np.mean(costs)
+
# use target keys in tvm target system as key to build best map
for k in inp.task.target.keys:
- key = (k, inp.task.workload_key)
- if key not in best_by_targetkey:
- best_by_targetkey[key] = (inp, res)
+ entry, _, workload_args = self.get_workload_entry(
+ best_by_targetkey, k, inp.task.workload_key
+ )
+ if workload_args not in entry:
+ entry[workload_args] = (inp.state, cost)
else:
- _, other_res = best_by_targetkey[key]
- other_costs = [x.value for x in other_res.costs if
isinstance(x, FloatImm)]
- costs = [x.value for x in res.costs if isinstance(x,
FloatImm)]
- if np.mean(other_costs) > np.mean(costs):
- best_by_targetkey[key] = (inp, res)
+ _, other_cost = entry[workload_args]
+ if other_cost > cost:
+ entry[workload_args] = (inp.state, cost)
# use model as key to build best map
- key = (inp.task.target.model, inp.task.workload_key)
- if key not in best_by_model:
+ entry, _, workload_args = self.get_workload_entry(
+ best_by_model, inp.task.target.model, inp.task.workload_key
+ )
+ if workload_args not in entry:
if inp.task.target.model != "unknown":
- best_by_model[key] = (inp, res)
+ entry[workload_args] = (inp.state, cost)
else:
- _, other_res = best_by_model[key]
- other_costs = [x.value for x in other_res.costs if
isinstance(x, FloatImm)]
- costs = [x.value for x in res.costs if isinstance(x, FloatImm)]
- if np.mean(other_costs) > np.mean(costs):
- best_by_model[key] = (inp, res)
+ _, other_cost = entry[workload_args]
+ if other_cost > cost:
+ entry[workload_args] = (inp.state, cost)
logger.debug("Finish loading %d records", counter)
@@ -205,31 +244,61 @@ class ApplyHistoryBest(DispatchContext):
" above the dispatcher call. So does other target. "
)
+ def match_record(best_records, target_key, workload_key):
+ """The helper function to match the record in the given map
+ and return the matched state, or None if no match.
+ """
+ ret = None
+
+ entry, workload_hash, workload_args = self.get_workload_entry(
+ best_records, target_key, workload_key
+ )
+ if workload_args in entry:
+ ret = entry[workload_args][0]
+ elif self.include_compatible:
+ best_cost = float("inf")
+ for args, val in entry.items():
+ dis_f = calc_workload_dis_factor(
+ (workload_hash, workload_args), (workload_hash, args)
+ )
+ if dis_f == float("inf"):
+ continue
+
+ state, cost = val
+ cost *= dis_f
+ if ret is None or cost < best_cost:
+ best_cost = cost
+ ret = state
+ return ret
+
# first try matching by model
- key = (target.model, workload_key)
- if key in self._best_user_defined:
- return self._best_user_defined[key]
- if key in self.best_by_model:
- return self.best_by_model[key][0].state
+ ret = match_record(self._best_user_defined, target.model, workload_key)
+ if ret is not None:
+ return ret
+ ret = match_record(self.best_by_model, target.model, workload_key)
+ if ret is not None:
+ return ret
# then try matching by target key
for k in target.keys:
- key = (k, workload_key)
- if key in self._best_user_defined:
- return self._best_user_defined[key]
- if key in self.best_by_targetkey:
- return self.best_by_targetkey[key][0].state
+ ret = match_record(self._best_user_defined, k, workload_key)
+ if ret is not None:
+ return ret
+ ret = match_record(self.best_by_targetkey, k, workload_key)
+ if ret is not None:
+ return ret
return None
def update(self, target, workload_key, state):
- model = target.model
- key = (model, workload_key)
- self._best_user_defined[key] = state
+ entry, _, workload_args = self.get_workload_entry(
+ self._best_user_defined, target.model, workload_key
+ )
+ entry[workload_args] = (state, 1)
for k in target.keys:
- key = (k, workload_key)
- self._best_user_defined[key] = state
+ entry, _, _ = self.get_workload_entry(self._best_user_defined, k,
workload_key)
+ entry[workload_args] = (state, 1)
class FallbackContext(DispatchContext):
diff --git a/python/tvm/auto_scheduler/measure_record.py
b/python/tvm/auto_scheduler/measure_record.py
index 9eaef18..200d24f 100644
--- a/python/tvm/auto_scheduler/measure_record.py
+++ b/python/tvm/auto_scheduler/measure_record.py
@@ -27,7 +27,7 @@ import numpy as np
import tvm._ffi
from tvm.runtime import Object
from .measure import MeasureErrorNo, MeasureCallback
-from .utils import decode_workload_key
+from .utils import calc_workload_dis_factor, decode_workload_key
from . import _ffi_api
logger = logging.getLogger("auto_scheduler")
@@ -130,65 +130,6 @@ class RecordReader(Object):
yield ret[0], ret[1] # (input, result)
-def calc_workload_dis_factor(target_workload_key, workload_key):
- """Calculate the distance factor of the workload to the target workload.
- If two workloads are not compatible at all (i.e., different compute DAG or
function),
- then the distance factor is "inf". Otherwise, we calculate the factor by
traversing
- the workload arguments, which are the arguments of the compute function,
- or the output shapes for the ComputeDAG. The factor is calculated by the
following rules:
-
- 1. For non-zero integer values: `product(target_arg / candidate_arg)`.
- 2. For non-integer or zero values: "inf" if not equal else 1.
-
- As a result, factor=1 is the optimal when two workloads are identical.
-
- Parameters
- ----------
- target_workload_key: str
- The target workload key in JSON string.
-
- workload_key: str
- The candidate workload key in JSON string.
-
- Returns
- -------
- dis_f: float
- The distance factor.
- """
-
- def flatten_list(inp):
- ret = []
- for elt in inp:
- if isinstance(elt, list):
- ret += flatten_list(elt)
- else:
- ret.append(elt)
- return ret
-
- target_key, target_args = decode_workload_key(target_workload_key)
- target_args = flatten_list(target_args) if target_args is not None else []
- key, args = decode_workload_key(workload_key)
- args = flatten_list(args) if args is not None else []
-
- # Not even the same func/DAG.
- if key != target_key or len(target_args) != len(args):
- return float("inf")
-
- dis_f = 1
- for target_arg, arg in zip(target_args, args):
- if isinstance(target_arg, int):
- if target_arg == 0 or arg == 0:
- if target_arg != arg:
- return float("inf")
- elif target_arg % arg != 0:
- return float("inf")
- else:
- dis_f *= target_arg / arg
- elif target_arg != arg:
- return float("inf")
- return dis_f
-
-
def load_record_from_string(record):
"""
Load the measure record from string.
@@ -304,7 +245,9 @@ def load_best_record(filename, workload_key=None,
target=None, include_compatibl
cost = np.mean(costs)
if workload_key is not None:
- dis_f = calc_workload_dis_factor(workload_key,
inp.task.workload_key)
+ dis_f = calc_workload_dis_factor(
+ decode_workload_key(workload_key),
decode_workload_key(inp.task.workload_key)
+ )
if dis_f == float("inf"):
continue
if not include_compatible and dis_f != 1:
diff --git a/python/tvm/auto_scheduler/utils.py
b/python/tvm/auto_scheduler/utils.py
index fd25fdb..8aa33e6 100644
--- a/python/tvm/auto_scheduler/utils.py
+++ b/python/tvm/auto_scheduler/utils.py
@@ -57,18 +57,77 @@ def decode_workload_key(workload_key):
-------
name: str
The workload function name or the DAG hash.
- args: Optional[List[Any]]
- The arguments of the workload, or None if the workload key format is
not decodeable.
+ args: Optional[Tuple[Any, ...]]
+ The flattened arguments in a tuple, or None if the workload key format
is not decodable.
"""
+
+ def flatten_list(inp):
+ ret = []
+ for elt in inp:
+ if isinstance(elt, list):
+ ret += flatten_list(elt)
+ else:
+ ret.append(elt)
+ return ret
+
try:
key_list = json.loads(workload_key)
if isinstance(key_list, list) and len(key_list) >= 1:
- return key_list[0], key_list[1:]
+ return key_list[0], tuple(flatten_list(key_list[1:]))
except json.decoder.JSONDecodeError:
pass
return workload_key, None
+def calc_workload_dis_factor(target_workload_pair, workload_pair):
+ """Calculate the distance factor of the workload to the target workload.
+ If two workloads are not compatible at all (i.e., different compute DAG or
function),
+ then the distance factor is "inf". Otherwise, we calculate the factor by
traversing
+ the workload arguments, which are the arguments of the compute function,
+ or the output shapes for the ComputeDAG. The factor is calculated by the
following rules:
+
+ 1. For non-zero integer values: `product(target_arg / candidate_arg)`.
+ 2. For non-integer or zero values: "inf" if not equal else 1.
+
+ As a result, factor=1 is the optimal when two workloads are identical.
+
+ Parameters
+ ----------
+ target_workload_pair: Tuple[str, Optional[Tuple[Any, ...]]]
+ The target workload pair: (hash, argument tuple).
+
+ workload_pair: Tuple[str, Optional[Tuple[Any, ...]]]
+ The candidate workload pair: (hash, argument tuple).
+
+ Returns
+ -------
+ dis_f: float
+ The distance factor.
+ """
+ target_key, target_args = target_workload_pair
+ target_args = target_args if target_args is not None else []
+ key, args = workload_pair
+ args = args if args is not None else []
+
+ # Not even the same func/DAG.
+ if key != target_key or len(target_args) != len(args):
+ return float("inf")
+
+ dis_f = 1
+ for target_arg, arg in zip(target_args, args):
+ if isinstance(target_arg, int):
+ if target_arg == 0 or arg == 0:
+ if target_arg != arg:
+ return float("inf")
+ elif target_arg % arg != 0:
+ return float("inf")
+ else:
+ dis_f *= target_arg / arg
+ elif target_arg != arg:
+ return float("inf")
+ return dis_f
+
+
def get_func_name(func):
"""Get name of a function.
diff --git a/tests/python/unittest/test_auto_scheduler_measure.py
b/tests/python/unittest/test_auto_scheduler_measure.py
index 3b074b2..041fb7e 100644
--- a/tests/python/unittest/test_auto_scheduler_measure.py
+++ b/tests/python/unittest/test_auto_scheduler_measure.py
@@ -202,35 +202,36 @@ def test_recover_measure_input():
def test_workload_dis_factor():
- calc = auto_scheduler.measure_record.calc_workload_dis_factor
+ calc = auto_scheduler.utils.calc_workload_dis_factor
+ decode = auto_scheduler.utils.decode_workload_key
# Identical
target_wkl_key = json.dumps(
["func1", [8, 3, 224, 224], [32, 3, 3, 3], [0, 0], [1, 1], "float32"]
)
- assert calc(target_wkl_key, target_wkl_key) == 1
+ assert calc(decode(target_wkl_key), decode(target_wkl_key)) == 1
# Compatible with a factor
wkl_key = json.dumps(["func1", [1, 3, 112, 112], [32, 3, 3, 3], [0, 0],
[1, 1], "float32"])
- assert calc(target_wkl_key, wkl_key) == 8 * 2 * 2
+ assert calc(decode(target_wkl_key), decode(wkl_key)) == 8 * 2 * 2
# Incompatible argument with zeros
wkl_key = json.dumps(["func1", [8, 3, 224, 224], [32, 3, 3, 3], [1, 1],
[1, 1], "float32"])
- assert calc(target_wkl_key, wkl_key) == float("inf")
+ assert calc(decode(target_wkl_key), decode(wkl_key)) == float("inf")
wkl_key = json.dumps(["func1", [8, 3, 224, 224], [32, 3, 3, 3], [0, 0],
[0, 0], "float32"])
- assert calc(target_wkl_key, wkl_key) == float("inf")
+ assert calc(decode(target_wkl_key), decode(wkl_key)) == float("inf")
# Incompatible non-integter argument
wkl_key = json.dumps(["func1", [8, 3, 224, 224], [32, 3, 3, 3], [0, 0],
[1, 1], "int8"])
- assert calc(target_wkl_key, wkl_key) == float("inf")
+ assert calc(decode(target_wkl_key), decode(wkl_key)) == float("inf")
# Incompatible function
wkl_key = json.dumps(["func2", [8, 3, 224, 224], [32, 3, 3, 3], [0, 0],
[1, 1], "float32"])
- assert calc(target_wkl_key, wkl_key) == float("inf")
+ assert calc(decode(target_wkl_key), decode(wkl_key)) == float("inf")
# Incompatible due to non-dividable factor
wkl_key = json.dumps(["func1", [8, 3, 223, 223], [32, 3, 3, 3], [0, 0],
[1, 1], "float32"])
- assert calc(target_wkl_key, wkl_key) == float("inf")
+ assert calc(decode(target_wkl_key), decode(wkl_key)) == float("inf")
def test_measure_local_builder_runner():
@@ -322,6 +323,7 @@ if __name__ == "__main__":
test_record_follow_split_follow_fused_split()
test_record_pragma_storage_align_rfactor()
test_recover_measure_input()
+ test_workload_dis_factor()
test_measure_local_builder_runner()
test_measure_local_builder_rpc_runner()
test_measure_target_host()