vinx13 opened a new issue, #13148:
URL: https://github.com/apache/tvm/issues/13148
TOPI batch matmul mistakenly marks the RHS tensor of batch matmul as a layout-free
placeholder even when it is a variable. As a result, `RewriteLayout` is applied to it,
and the inserted layout transform can't be constant-folded in Relay. This causes a
workload mismatch: the `meta_schedule_layout_transform` op is fused with other
operators, producing a new workload that hasn't been tuned.
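To make the mismatch concrete, here is a toy Python model of it. This is not TVM code: the kernel names are hypothetical simplifications, and operator fusion is reduced to a single rule. It compares the kernels seen during tuning (task extraction, before layout rewriting) with the kernels requested at compile time for `batch_matmul(x, transpose(y))`.

```python
# Toy model of the tuning/compile workload mismatch. NOT TVM's API:
# kernel names are hypothetical and fusion is reduced to one rule.


def tuned_workloads():
    """Kernels seen during task extraction, before layout rewriting runs."""
    return {"fused_transpose", "fused_nn_batch_matmul"}


def compiled_workloads(rhs_is_constant, rhs_layout_free):
    """Kernels requested at compile time for batch_matmul(x, transpose(y))."""
    kernels = {"fused_nn_batch_matmul"}
    if rhs_is_constant:
        # transpose(y) and any layout transform on it fold into a new constant.
        return kernels
    if rhs_layout_free:
        # RewriteLayout inserts a layout transform on the RHS. Since y is a
        # runtime variable it cannot be folded, so it is fused with the
        # transpose into a kernel that was never tuned.
        kernels.add("fused_transpose_meta_schedule_layout_transform")
    else:
        kernels.add("fused_transpose")
    return kernels


def missing_workloads(rhs_is_constant, rhs_layout_free):
    """Kernels needed at compile time that the tuning database lacks."""
    return compiled_workloads(rhs_is_constant, rhs_layout_free) - tuned_workloads()


# The case in this issue: a variable RHS wrongly marked as layout free.
print(sorted(missing_workloads(rhs_is_constant=False, rhs_layout_free=True)))
# A variable RHS that is not marked as layout free.
print(sorted(missing_workloads(rhs_is_constant=False, rhs_layout_free=False)))
```

Under this model, the first call reports the untuned `fused_transpose_meta_schedule_layout_transform` kernel and the second reports nothing, matching the issue's point that a variable RHS should not be treated as a layout-free placeholder.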
### Expected behavior
Successfully tune and compile the Relay function.
### Actual behavior
One workload is missing from the tuning database:
```
src/relay/backend/te_compiler_cache.cc:544: Warning: Cannot find workload: vm_mod_fused_transpose_meta_schedule_layout_transform
# from tvm.script import tir as T
@T.prim_func
def func(p0: T.Buffer[(12, 64, 197), "int8"], T_meta_schedule_layout_trans: T.Buffer[(12, 64, 197), "int8"]) -> None:
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # body
    # with T.block("root")
    T_transpose = T.alloc_buffer([12, 197, 64], dtype="int8")
    for i0, i1, i2 in T.grid(12, 197, 64):
        with T.block("T_transpose"):
            ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
            T.reads(p0[ax0, ax2, ax1])
            T.writes(T_transpose[ax0, ax1, ax2])
            T_transpose[ax0, ax1, ax2] = p0[ax0, ax2, ax1]
    for i0, i1, i2 in T.grid(12, 64, 197):
        with T.block("T_meta_schedule_layout_trans"):
            ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
            T.reads(T_transpose[ax0, ax2, ax1])
            T.writes(T_meta_schedule_layout_trans[ax0, ax1, ax2])
            T_meta_schedule_layout_trans[ax0, ax1, ax2] = T_transpose[ax0, ax2, ax1]
```
### Environment
TVM v0.11.dev 5eab64885ad4
### Steps to reproduce
```python
import tvm
from tvm import relay
import tvm.meta_schedule as ms

# batch_matmul whose RHS is a transposed *variable*, not a constant weight
x = relay.var("x", shape=(12, 197, 64), dtype="int8")
y = relay.var("y", shape=[12, 64, 197], dtype="int8")
y1 = relay.transpose(y, [0, 2, 1])
mm = relay.nn.batch_matmul(x, y1, out_dtype="int32", transpose_b=True)
func = relay.Function([x, y], mm)
mod = tvm.ir.IRModule({"main": func})

target = tvm.target.Target("aws/cpu/c5.12xlarge")
# Tune all extracted tasks (no params are bound)
database = ms.relay_integration.tune_relay(
    mod, {}, target=target, work_dir="./work_dir", max_trials_global=200
)
# Compiling with the tuned database emits the "Cannot find workload" warning
lib = ms.relay_integration.compile_relay(
    database=database, mod=mod, target=target, params={}, backend="vm"
)
```
cc @zxybazh @junrushao