vinx13 opened a new issue, #13148:
URL: https://github.com/apache/tvm/issues/13148

   TOPI batch matmul mistakenly marks the RHS tensor of batch matmul as a layout-free placeholder even when it is a variable. As a result, `RewriteLayout` is applied to it, and the layout transform it inserts cannot be constant-folded in Relay. This causes a workload mismatch: the `meta_schedule_layout_transform` op is fused with other operators (here, the preceding `transpose`), producing a new workload that has never been tuned.
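The failure mode can be illustrated with a toy model of the decision described above. Everything in this sketch (`Input`, `should_mark_layout_free`, `buggy_mark`, `layout_transform_fate`) is hypothetical and is not a TVM API; it only encodes the rule implied by the report, namely that the RHS should be treated as layout-free only when it is a constant, since only then can the inserted layout transform be folded away before codegen.

```python
from dataclasses import dataclass

# Toy model only: these names are hypothetical and are not TVM APIs.
@dataclass
class Input:
    name: str
    is_constant: bool  # True for a bound weight, False for a relay.var input


def should_mark_layout_free(rhs: Input) -> bool:
    # Rule implied by the issue: only a constant RHS may be marked layout-free.
    return rhs.is_constant


def buggy_mark(rhs: Input) -> bool:
    # Models the reported behavior: the RHS is marked even when it is a variable.
    return True


def layout_transform_fate(rhs: Input, mark_layout_free) -> str:
    if not mark_layout_free(rhs):
        return "no layout rewrite"
    if rhs.is_constant:
        # A constant input lets the inserted layout transform be constant-folded.
        return "folded at compile time"
    # A variable input cannot be folded, so the transform stays in the graph and
    # gets fused with neighboring ops into a workload that was never tuned.
    return "runtime op (new untuned workload)"


y = Input("y", is_constant=False)
w = Input("w", is_constant=True)
print(layout_transform_fate(y, buggy_mark))
print(layout_transform_fate(y, should_mark_layout_free))
print(layout_transform_fate(w, should_mark_layout_free))
```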
   
   ### Expected behavior
   
   Successfully tune and compile the Relay function.
   
   ### Actual behavior
   
   One workload is missing from the tuning database.
   
   ```
   src/relay/backend/te_compiler_cache.cc:544: Warning: Cannot find workload: vm_mod_fused_transpose_meta_schedule_layout_transform
   # from tvm.script import tir as T
   @T.prim_func
   def func(p0: T.Buffer[(12, 64, 197), "int8"], T_meta_schedule_layout_trans: T.Buffer[(12, 64, 197), "int8"]) -> None:
       # function attr dict
       T.func_attr({"global_symbol": "main", "tir.noalias": True})
       # body
       # with T.block("root")
       T_transpose = T.alloc_buffer([12, 197, 64], dtype="int8")
       for i0, i1, i2 in T.grid(12, 197, 64):
           with T.block("T_transpose"):
               ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
               T.reads(p0[ax0, ax2, ax1])
               T.writes(T_transpose[ax0, ax1, ax2])
               T_transpose[ax0, ax1, ax2] = p0[ax0, ax2, ax1]
       for i0, i1, i2 in T.grid(12, 64, 197):
           with T.block("T_meta_schedule_layout_trans"):
               ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
               T.reads(T_transpose[ax0, ax2, ax1])
               T.writes(T_meta_schedule_layout_trans[ax0, ax1, ax2])
               T_meta_schedule_layout_trans[ax0, ax1, ax2] = T_transpose[ax0, ax2, ax1]
   ```
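Read together, the two blocks in the untuned function each swap the last two axes: `T_transpose` swaps them on `p0`, and the layout transform swaps them back. The following NumPy sketch replays the two index maps exactly as written in the dump (the random int8 input is an illustrative assumption) and suggests that the fused workload amounts to an element-wise copy of `p0`.

```python
import numpy as np

rng = np.random.default_rng(0)
p0 = rng.integers(-128, 128, size=(12, 64, 197)).astype(np.int8)

# Block "T_transpose": T_transpose[ax0, ax1, ax2] = p0[ax0, ax2, ax1]
t_transpose = np.transpose(p0, (0, 2, 1))  # shape (12, 197, 64)

# Block "T_meta_schedule_layout_trans":
#   out[ax0, ax1, ax2] = T_transpose[ax0, ax2, ax1]
out = np.transpose(t_transpose, (0, 2, 1))  # shape (12, 64, 197)

# The two axis swaps cancel, so the result matches the input element for element.
print(out.shape, np.array_equal(out, p0))
```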
   
   ### Environment
   
   TVM v0.11.dev 5eab64885ad4
   
   ### Steps to reproduce
   
   ```
   import tvm
   from tvm import relay
   x = relay.var("x", shape=(12, 197, 64), dtype="int8")
   y = relay.var("y", shape=[12, 64, 197], dtype="int8")
   y1 = relay.transpose(y, [0, 2, 1])
   mm = relay.nn.batch_matmul(x, y1, out_dtype="int32", transpose_b=True)
   
   func = relay.Function([x, y], mm)
   mod = tvm.ir.IRModule({"main": func})
   
   import tvm.meta_schedule as ms
   target = tvm.target.Target("aws/cpu/c5.12xlarge")
   database = ms.relay_integration.tune_relay(
       mod, {}, target=target, work_dir="./work_dir", max_trials_global=200
   )
   lib = ms.relay_integration.compile_relay(
       database=database, mod=mod, target=target, params={}, backend='vm'
   )
   ```
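Once compilation succeeds, a NumPy reference for the Relay function in the script can be used to check the module's output. This sketch assumes Relay's `nn.batch_matmul` semantics with `transpose_b=True`, i.e. each batch of `x` is multiplied by its second argument transposed on the last two axes, accumulated in `int32` as requested by `out_dtype`. The random int8 inputs are illustrative, and running `lib` itself is omitted here because that depends on the VM executor setup.

```python
import numpy as np


def reference(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """NumPy model of the Relay function from the reproduction script."""
    # y1 = relay.transpose(y, [0, 2, 1])
    y1 = np.transpose(y, (0, 2, 1))  # (12, 197, 64)
    # batch_matmul(x, y1, transpose_b=True): x @ y1^T per batch, in int32.
    return np.matmul(x.astype(np.int32), np.transpose(y1, (0, 2, 1)).astype(np.int32))


rng = np.random.default_rng(0)
x = rng.integers(-128, 128, size=(12, 197, 64)).astype(np.int8)
y = rng.integers(-128, 128, size=(12, 64, 197)).astype(np.int8)
ref = reference(x, y)
print(ref.shape, ref.dtype)
# transpose followed by transpose_b=True cancels, so this is a plain batched x @ y.
print(np.array_equal(ref, np.matmul(x.astype(np.int32), y.astype(np.int32))))
```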
   
   cc @zxybazh @junrushao 
   