yzh119 opened a new issue #9953:
URL: https://github.com/apache/tvm/issues/9953
The current TIR syntax printer (introduced in #9680) fails when the script contains dynamic shapes:
```python
from tvm.script import tir as T

@T.prim_func
def f(a: T.handle, b: T.handle, c: T.handle):
    N = T.var("int32")
    M = T.var("int32")
    K = T.var("int32")
    A = T.match_buffer(a, (N, K), "float32")
    B = T.match_buffer(b, (K, M), "float32")
    C = T.match_buffer(c, (N, M), "float32")
    for i, j, k in T.grid(N, M, K):
        with T.block("gemm"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = 0.
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

print(f.script())
```
### Expected behavior
The printed script should be equivalent to the input script.
### Actual behavior
The variables `M`, `N`, and `K` are used in the buffer annotations of the function signature before they are declared in the body:
```python
# from tvm.script import tir as T
@T.prim_func
def func(A: T.Buffer[(N, K), "float32"], B: T.Buffer[(K, M), "float32"], C: T.Buffer[(N, M), "float32"]) -> None:
    K = T.var("int32")
    M = T.var("int32")
    N = T.var("int32")
    # body
    # with T.block("root")
    for i, j, k in T.grid(N, M, K):
        with T.block("gemm"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            T.reads(C[vi, vj], A[vi, vk], B[vk, vj])
            T.writes(C[vi, vj])
            with T.init():
                C[vi, vj] = T.float32(0)
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
```
The same problem occurs when the tensor shapes are passed as parameters:
```python
@T.prim_func
def f(a: T.handle, b: T.handle, c: T.handle, N: T.int32, M: T.int32, K: T.int32):
    A = T.match_buffer(a, (N, K), "float32")
    B = T.match_buffer(b, (K, M), "float32")
    C = T.match_buffer(c, (N, M), "float32")
    for i, j, k in T.grid(N, M, K):
        with T.block("gemm"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = 0.
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
```