yzh119 opened a new issue #9953:
URL: https://github.com/apache/tvm/issues/9953


   The current TIR syntax printer (introduced in #9680) fails when the script contains dynamic shapes:
   
   ```python
   from tvm.script import tir as T

   @T.prim_func
   def f(a: T.handle, b: T.handle, c: T.handle):
       N = T.var("int32")
       M = T.var("int32")
       K = T.var("int32")
       A = T.match_buffer(a, (N, K), "float32")
       B = T.match_buffer(b, (K, M), "float32")
       C = T.match_buffer(c, (N, M), "float32")
       for i, j, k in T.grid(N, M, K):
           with T.block("gemm"):
               vi, vj, vk = T.axis.remap("SSR", [i, j, k])
               with T.init():
                   C[vi, vj] = 0.
               C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
   
   print(f.script())
   ```
   
   ### Expected behavior
   The printed script should be the same as the input.
   
   ### Actual behavior
   
   `M`, `N`, and `K` are used in the buffer annotations of the function signature before they are declared in the function body:
   ```python
   # from tvm.script import tir as T
   @T.prim_func
   def func(A: T.Buffer[(N, K), "float32"], B: T.Buffer[(K, M), "float32"], C: T.Buffer[(N, M), "float32"]) -> None:
       K = T.var("int32")
       M = T.var("int32")
       N = T.var("int32")
       # body
       # with T.block("root")
       for i, j, k in T.grid(N, M, K):
           with T.block("gemm"):
               vi, vj, vk = T.axis.remap("SSR", [i, j, k])
               T.reads(C[vi, vj], A[vi, vk], B[vk, vj])
               T.writes(C[vi, vj])
               with T.init():
                   C[vi, vj] = T.float32(0)
               C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
   ```
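To make the failure mode concrete, here is a minimal, hypothetical checker built only on Python's standard `ast` module (it is not part of TVM). It parses a printed script as plain Python source and reports names that a parameter annotation references but that are only declared later, either by a subsequent parameter or by an assignment in the function body. It executes nothing and knows nothing about TIR semantics; `used_before_declaration` and `PRINTED` are illustrative names.

```python
import ast
from typing import List

# Signature and variable declarations copied from the printed output in this
# issue (body trimmed; the check only looks at parameters and assignments).
PRINTED = '''
@T.prim_func
def func(A: T.Buffer[(N, K), "float32"], B: T.Buffer[(K, M), "float32"], C: T.Buffer[(N, M), "float32"]) -> None:
    K = T.var("int32")
    M = T.var("int32")
    N = T.var("int32")
'''


def used_before_declaration(source: str) -> List[str]:
    """Return names referenced in a parameter annotation of the first function
    in `source` that are only declared afterwards, either by a later parameter
    or by a simple assignment in the function body (hypothetical helper)."""
    tree = ast.parse(source)
    func = next((n for n in ast.walk(tree) if isinstance(n, ast.FunctionDef)), None)
    if func is None:
        raise ValueError("source contains no function definition")

    params = [arg.arg for arg in func.args.args]

    # Names bound by `name = ...` statements anywhere inside the body.
    body_names = set()
    for stmt in func.body:
        for node in ast.walk(stmt):
            if isinstance(node, ast.Assign):
                for target in node.targets:
                    if isinstance(target, ast.Name):
                        body_names.add(target.id)

    flagged = set()
    for index, arg in enumerate(func.args.args):
        if arg.annotation is None:
            continue
        declared_before = set(params[:index])
        declared_after = set(params[index + 1:]) | body_names
        # Any plain name inside the annotation that is only bound later is flagged.
        for node in ast.walk(arg.annotation):
            if (
                isinstance(node, ast.Name)
                and node.id in declared_after
                and node.id not in declared_before
            ):
                flagged.add(node.id)
    return sorted(flagged)


print(used_before_declaration(PRINTED))  # ['K', 'M', 'N']
```

On the printed output it reports `['K', 'M', 'N']`, while the original input (plain `T.handle` parameters with `T.match_buffer` in the body) yields an empty list. Because the sketch also treats later parameters as late declarations, it would likewise flag a `T.Buffer[...]` annotation that refers to `N`, `M`, or `K` when those are trailing `T.int32` parameters.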
   
   The same problem occurs when the tensor shapes are passed as parameters:
   ```python
   @T.prim_func
   def f(a: T.handle, b: T.handle, c: T.handle, N: T.int32, M: T.int32, K: T.int32):
       A = T.match_buffer(a, (N, K), "float32")
       B = T.match_buffer(b, (K, M), "float32")
       C = T.match_buffer(c, (N, M), "float32")
       for i, j, k in T.grid(N, M, K):
           with T.block("gemm"):
               vi, vj, vk = T.axis.remap("SSR", [i, j, k])
               with T.init():
                   C[vi, vj] = 0.
               C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
   ```
   

