Hzfengsy opened a new pull request #9691:
URL: https://github.com/apache/tvm/pull/9691


   Currently, `te.create_prim_func` will not simplify the loop extents and other 
PrimExpr, which causes problems when we have `Select` in the computation, such as 
`topi.nn.adaptive_pool`.
   
   Previously, the generated code was as follows, with `T.Select` and two undefined 
vars.
   ```
   # from tvm.script import tir as T
   @T.prim_func
   def func(var_placeholder: T.handle, var_tensor: T.handle) -> None:
       ax2 = T.var("int32")
       ax3 = T.var("int32")
       placeholder = T.match_buffer(var_placeholder, [1, 128, 10, 10, 4], 
dtype="float32")
       tensor = T.match_buffer(var_tensor, [1, 128, 1, 1, 4], dtype="float32")
       # body
       # with T.block("root")
       tensor_1 = T.alloc_buffer([1, 128, 1, 1, 4], dtype="float32")
       for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 128, 1, 1, 4, T.Select(True, 
(ax2 + 1) * 10, (ax2 + 1) * 10 + 1) - ax2 * 10, T.Select(True, (ax3 + 1) * 10, 
(ax3 + 1) * 10 + 1) - ax3 * 10):
           with T.block("tensor"):
               ax0, ax1, ax2_1, ax3_1, ax4, rv0, rv1 = T.axis.remap("SSSSSRR", 
[i0, i1, i2, i3, i4, i5, i6])
               T.reads([tensor_1[ax0, ax1, ax2_1, ax3_1, ax4], placeholder[ax0, 
ax1, ax2_1 * 10 + rv0, ax3_1 * 10 + rv1, ax4]])
               T.writes([tensor_1[ax0, ax1, ax2_1, ax3_1, ax4]])
               with T.init():
                   tensor_1[ax0, ax1, ax2_1, ax3_1, ax4] = T.float32(0)
               tensor_1[ax0, ax1, ax2_1, ax3_1, ax4] = tensor_1[ax0, ax1, 
ax2_1, ax3_1, ax4] + placeholder[ax0, ax1, ax2_1 * 10 + rv0, ax3_1 * 10 + rv1, 
ax4]
       for i0, i1, i2, i3, i4 in T.grid(1, 128, 1, 1, 4):
           with T.block("tensor_1"):
               ax0, ax1, ax2_2, ax3_2, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, 
i3, i4])
               T.reads([tensor_1[ax0, ax1, ax2_2, ax3_2, ax4]])
               T.writes([tensor[ax0, ax1, ax2_2, ax3_2, ax4]])
               tensor[ax0, ax1, ax2_2, ax3_2, ax4] = tensor_1[ax0, ax1, ax2_2, 
ax3_2, ax4] / (T.cast(T.Select(True, (ax2_2 + 1) * 10, (ax2_2 + 1) * 10 + 1) - 
ax2_2 * 10, "float32") * T.cast(T.Select(True, (ax3_2 + 1) * 10, (ax3_2 + 1) * 
10 + 1) - ax3_2 * 10, "float32"))
   ```
   
   After the fix, it's:
   ```
   # from tvm.script import tir as T
   @T.prim_func
   def func(var_placeholder: T.handle, var_tensor: T.handle) -> None:
       placeholder = T.match_buffer(var_placeholder, [1, 128, 10, 10, 4], 
dtype="float32")
       tensor = T.match_buffer(var_tensor, [1, 128, 1, 1, 4], dtype="float32")
       # body
       # with T.block("root")
       tensor_1 = T.alloc_buffer([1, 128, 1, 1, 4], dtype="float32")
       for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 128, 1, 1, 4, 10, 10):
           with T.block("tensor"):
               ax0, ax1, ax2, ax3, ax4, rv0, rv1 = T.axis.remap("SSSSSRR", [i0, 
i1, i2, i3, i4, i5, i6])
               T.reads([tensor_1[ax0, ax1, ax2, ax3, ax4], placeholder[ax0, 
ax1, ax2 * 10 + rv0, ax3 * 10 + rv1, ax4]])
               T.writes([tensor_1[ax0, ax1, ax2, ax3, ax4]])
               with T.init():
                   tensor_1[ax0, ax1, ax2, ax3, ax4] = T.float32(0)
               tensor_1[ax0, ax1, ax2, ax3, ax4] = tensor_1[ax0, ax1, ax2, ax3, 
ax4] + placeholder[ax0, ax1, ax2 * 10 + rv0, ax3 * 10 + rv1, ax4]
       for i0, i1, i2, i3, i4 in T.grid(1, 128, 1, 1, 4):
           with T.block("tensor_1"):
               ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, 
i4])
               T.reads([tensor_1[ax0, ax1, ax2, ax3, ax4]])
               T.writes([tensor[ax0, ax1, ax2, ax3, ax4]])
               tensor[ax0, ax1, ax2, ax3, ax4] = tensor_1[ax0, ax1, ax2, ax3, 
ax4] * T.float32(0.01)
   ```
   
   cc @junrushao1994 @zxybazh 


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to