comaniac opened a new pull request #7317:
URL: https://github.com/apache/tvm/pull/7317


   In this PR, we attempt to enable schedule sharing as a workaround before dynamic shape support fully lands. The idea is that if we have a schedule for batch size 1, it is actually applicable to all other batch sizes (regardless of the performance). This is useful when we only tune the workload with batch size 1 but wish to use it for all batch sizes, so that the flow at least works.
   
   To do so, we introduce a "workload distance factor", which indicates the similarity of two workloads. Specifically, it is calculated by the following rules:
   - If two workloads are not for the same compute DAG or function, the factor is `inf`.
   - If two workloads are for the same compute DAG/function, and
       - their non-zero integer arguments are divisible (and their zero and non-integer arguments are identical), then `factor = prod(a / b for a, b in zip(wkl1.args, wkl2.args))`,
       - otherwise the factor is `inf`.
   
   As a result, the distance factor ranges from 1 to `inf`. When the distance factor is not `inf`, it is safe to apply the schedule of workload 2 to workload 1, as sketched below.
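   
   For illustration, here is a minimal Python sketch of the rules above (a hypothetical `distance_factor` helper operating on workload argument tuples, not the exact implementation in this PR):
   
   ```python
   import math
   
   
   def distance_factor(args1, args2):
       """Sketch of the distance-factor rules; assumes the compute
       DAG/function hash has already been checked to be identical."""
       if len(args1) != len(args2):
           return math.inf
       factor = 1
       for a, b in zip(args1, args2):
           if isinstance(a, int) and isinstance(b, int) and a != 0 and b != 0:
               if a % b != 0:
                   return math.inf  # non-zero integers must be divisible
               factor *= a // b
           elif a != b:
               return math.inf  # zero/non-integer arguments must be the same
       return factor
   
   
   # Applying the batch-1 schedule to the batch-4 workload gives a finite factor:
   assert distance_factor((4, 7, 7, 512), (1, 7, 7, 512)) == 4
   # Batch 3 vs. batch 2 does not:
   assert distance_factor((3, 7, 7, 512), (2, 7, 7, 512)) == math.inf
   ```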
   
   The above mechanism works well for registered TE computes, but not for ComputeDAGs extracted from Relay programs. This is because, when extracting tasks from Relay, we currently use the MD5 hash of the serialized ComputeDAG string as its key, which covers not only the DAG structure but also the tensor shapes, so it is impossible to calculate the distance factor. To make it work, this PR also improves the hashing mechanism of ComputeDAG by separating out the input/output tensor shapes so that they can be accessed. For example, the workload key of a ComputeDAG was:
   
   ```
   ["8d5a93959138dc7b2ee1f1b3219dfa14"]
   ```
   
   and it now becomes:
   
   ```
   ["ad6cecbf5d85cb1cda3c2bb7af170211", 1, 7, 7, 512, 4, 4, 512, 512, 1, 7, 7, 
512, 1, 1, 1, 512, 1, 1, 1, 512, 1, 7, 7, 512]
   ```
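   
   The first element of the new key is the structural hash of the DAG, and the remaining integers are the flattened input/output tensor shapes, so they can now be accessed directly. For example (an illustrative snippet, not part of this PR):
   
   ```python
   import json
   
   # Decode the new-style workload key shown above: the first element is the
   # structural hash; the rest are the flattened input/output tensor shapes.
   key = (
       '["ad6cecbf5d85cb1cda3c2bb7af170211", 1, 7, 7, 512, 4, 4, 512, 512, '
       '1, 7, 7, 512, 1, 1, 1, 512, 1, 1, 1, 512, 1, 7, 7, 512]'
   )
   dag_hash, *shape_args = json.loads(key)
   print(dag_hash)    # ad6cecbf5d85cb1cda3c2bb7af170211
   print(shape_args)  # [1, 7, 7, 512, 4, 4, 512, 512, ...]
   ```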
   
   Please note that since we change the workload key format of ComputeDAG, existing tuning logs won't match anymore. To make them work again, we can use the following script to update the keys in existing log files. This is also how I updated the CI logs:
   
   ```python
   import json
   import hashlib
   import os
   
   from tvm.te import ComputeOp, PlaceholderOp
   
   from tvm.auto_scheduler import save_records
   from tvm.auto_scheduler.measure import MeasureInput
   from tvm.auto_scheduler.measure_record import load_records
   from tvm.auto_scheduler.utils import get_const_tuple
   
    tasks = []  # Extract tasks from a Relay program (see the sketch after this script)
   log_file = "old-log-file"
   new_log_file = "new-log-file"
   
   def get_old_hash_key(dag):
       """Return the hash key of a compute DAG."""
       str_key = ""
       for op in dag.ops:
           t = op.output(0)
           if isinstance(op, PlaceholderOp):
               str_key += "placeholder,"
               str_key += str(get_const_tuple(t.shape)) + ","
               str_key += t.dtype + ";"
           elif isinstance(op, ComputeOp):
               str_key += str(t.op.body) + ","
               str_key += str(get_const_tuple(t.shape)) + ","
               str_key += t.dtype + ";"
           else:
               raise ValueError("Invalid op: " + str(op))
   
       str_key = str_key.encode(encoding="utf-8")
       return hashlib.md5(str_key).hexdigest()
   
   
   # Establish the key mapping
   old_key_to_task = {}
   hit_count = {}
   for idx, task in enumerate(tasks):
       old_key = json.dumps((get_old_hash_key(task.compute_dag),))
       old_key_to_task[old_key] = task
       hit_count[old_key] = 0
       print("Task %d %s -> %s" % (idx, old_key, task.workload_key))
   
   
   # Update the workload key in an existing log file
   new_inputs = []
   new_results = []
   for inp, res in load_records(log_file):
       if inp.task.workload_key not in old_key_to_task:
           print(
               "Ignore key %s in log file due to no corresponding task found"
               % inp.task.workload_key
           )
           continue
       hit_count[inp.task.workload_key] += 1
       new_inputs.append(
           MeasureInput(old_key_to_task[inp.task.workload_key], inp.state)
       )
       new_results.append(res)
   
   for key, cnt in hit_count.items():
       print("Old key %s hits %d times" % (key, cnt))
   
   if os.path.exists(new_log_file):
       os.remove(new_log_file)
   save_records(new_log_file, new_inputs, new_results)
   
   ```
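   
   The `tasks` list in the script above has to be populated with the same tasks that produced the old logs. A minimal sketch, assuming the Relay module `mod`, its `params`, and the tuning `target` are still available (these names are placeholders):
   
   ```python
   from tvm import auto_scheduler
   
   # Re-extract the search tasks from the original Relay program so that the
   # old hash keys can be recomputed from each task's ComputeDAG.
   tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target)
   ```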
   
   cc @merrymercy @jcf94 

