junrushao1994 opened a new issue, #11746:
URL: https://github.com/apache/tvm/issues/11746

   Adaptive pooling operations may introduce dynamism into loop extents even when the tensor shapes themselves are static. The script below shows an interesting example:
   
   ```python
   from tvm import te, topi
   
   # def @fused_nn.adaptive_avg_pool2d(%p0: Tensor[(1, 8, 8, 512), float32] /* ty=Tensor[(1, 8, 8, 512), float32] */, Primitive=1) -> Tensor[(1, 7, 7, 512), float32] {
   #   nn.adaptive_avg_pool2d(%p0, output_size=[7, 7], layout="NHWC") /* ty=Tensor[(1, 7, 7, 512), float32] */
   # }
   
   def main():
       A = te.placeholder((1, 8, 8, 512), "float32", name="A")
       B = topi.nn.adaptive_pool(A, (7, 7), "avg", "NHWC")
       func = te.create_prim_func([A, B])
       print(func.script())
   
   if __name__ == "__main__":
       main()
   ```
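   The dynamism comes from how adaptive pooling derives its windows: output bin `i` reads input positions from `floor(i * in / out)` up to `ceil((i + 1) * in / out)`, so the window extent is a function of the output index (this is exactly the arithmetic that shows up in the generated TIR). A minimal plain-Python sketch of the bin arithmetic, with no TVM dependency (`adaptive_bins` is a hypothetical helper name, not a TVM API):

   ```python
   def adaptive_bins(in_size: int, out_size: int):
       """(start, end) input indices for each adaptive-pooling output bin."""
       bins = []
       for i in range(out_size):
           start = (i * in_size) // out_size             # floor(i * in / out)
           end = -((-(i + 1) * in_size) // out_size)     # ceil((i + 1) * in / out)
           bins.append((start, end))
       return bins

   # For the 8 -> 7 case in the script above, every window happens to be
   # exactly 2 elements wide, but the bound expression still depends on i:
   print(adaptive_bins(8, 7))  # [(0, 2), (1, 3), (2, 4), (3, 5), (4, 6), (5, 7), (6, 8)]

   # For shapes that do not divide evenly (e.g. 8 -> 3), the widths differ:
   print(adaptive_bins(8, 3))  # [(0, 3), (2, 6), (5, 8)]
   ```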
   
   The script produces the TIR below:
   
   ```python
   @T.prim_func
   def func(A: T.Buffer[(1, 8, 8, 512), "float32"], tensor: T.Buffer[(1, 7, 7, 512), "float32"]) -> None:
       # function attr dict
       T.func_attr({"global_symbol": "main", "tir.noalias": True})
       ax1 = T.var("int32")   # <========== undefined variable
       ax2 = T.var("int32")   # <========== undefined variable
       # body
       # with T.block("root")
       tensor_1 = T.alloc_buffer([1, 7, 7, 512], dtype="float32")
       for i0, i1, i2, i3, i4, i5 in T.grid(1, 7, 7, 512, T.Select((ax1 + 1) % 7 == 0, (ax1 * 8 + 8) // 7, (ax1 * 8 + 8) // 7 + 1) - ax1 * 8 // 7, T.Select((ax2 + 1) % 7 == 0, (ax2 * 8 + 8) // 7, (ax2 * 8 + 8) // 7 + 1) - ax2 * 8 // 7):
           with T.block("tensor"):
               ax0, ax1_1, ax2_1, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5])
               T.reads(A[ax0, ax1_1 * 8 // 7 + rv0, ax2_1 * 8 // 7 + rv1, ax3])
               T.writes(tensor_1[ax0, ax1_1, ax2_1, ax3])
               with T.init():
                   tensor_1[ax0, ax1_1, ax2_1, ax3] = T.float32(0)
               tensor_1[ax0, ax1_1, ax2_1, ax3] = tensor_1[ax0, ax1_1, ax2_1, ax3] + A[ax0, ax1_1 * 8 // 7 + rv0, ax2_1 * 8 // 7 + rv1, ax3]
       for i0, i1, i2, i3 in T.grid(1, 7, 7, 512):
           with T.block("tensor_1"):
               ax0, ax1_2, ax2_2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
               T.reads(tensor_1[ax0, ax1_2, ax2_2, ax3])
               T.writes(tensor[ax0, ax1_2, ax2_2, ax3])
               tensor[ax0, ax1_2, ax2_2, ax3] = tensor_1[ax0, ax1_2, ax2_2, ax3] / (T.cast(T.Select((ax1_2 + 1) % 7 == 0, (ax1_2 * 8 + 8) // 7, (ax1_2 * 8 + 8) // 7 + 1) - ax1_2 * 8 // 7, "float32") * T.cast(T.Select((ax2_2 + 1) % 7 == 0, (ax2_2 * 8 + 8) // 7, (ax2_2 * 8 + 8) // 7 + 1) - ax2_2 * 8 // 7, "float32"))
   ```
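   For reference, the `T.Select(...)` extent in the `T.grid` above is a ceiling division written out explicitly: `T.Select((ax1 + 1) % 7 == 0, (ax1 * 8 + 8) // 7, (ax1 * 8 + 8) // 7 + 1)` computes `ceil((ax1 + 1) * 8 / 7)` (the modulus test is presumably pre-simplified using 8 ≡ 1 mod 7), and subtracting `ax1 * 8 // 7` gives the window length for output row `ax1`. A quick plain-Python check of that reading (`select_extent` is a hypothetical helper transcribing the TIR expression, not a TVM API):

   ```python
   import math

   def select_extent(ax1: int) -> int:
       # Literal transcription of the TIR extent for the i4 loop:
       #   T.Select((ax1 + 1) % 7 == 0,
       #            (ax1 * 8 + 8) // 7,
       #            (ax1 * 8 + 8) // 7 + 1) - ax1 * 8 // 7
       ceil_end = (ax1 * 8 + 8) // 7 if (ax1 + 1) % 7 == 0 else (ax1 * 8 + 8) // 7 + 1
       return ceil_end - ax1 * 8 // 7

   # The Select matches ceil((ax1 + 1) * 8 / 7) minus floor(ax1 * 8 / 7):
   for ax1 in range(7):
       assert select_extent(ax1) == math.ceil((ax1 + 1) * 8 / 7) - (ax1 * 8) // 7

   # For this 8 -> 7 shape the extent is a constant 2 for every row, yet the
   # expression still references ax1 -- and ax1 is only bound inside the block,
   # not at the T.grid that consumes the extent.
   print([select_extent(ax1) for ax1 in range(7)])  # [2, 2, 2, 2, 2, 2, 2]
   ```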
   
   As the annotated lines show, `ax1` and `ax2` are undefined variables in the IR, which makes the IR ill-formed. This subsequently breaks various analyses during default MetaSchedule auto-scheduling. Therefore, we might want to at least fix `te.create_prim_func` to generate well-formed IR.
   
   As an interesting side note, Ansor is somehow able to schedule this program, even though it assumes shapes are static. Example:
   
   ```python
   from tvm import IRModule, auto_scheduler, relay
   
   
   def main():
       a = relay.var("a", shape=(1, 8, 8, 512), dtype="float32")
       b = relay.nn.adaptive_avg_pool2d(a, (7, 7), "NHWC")
       mod = IRModule({"main": relay.Function([a], b)})
       tasks, task_weights = auto_scheduler.extract_tasks(
           mod,
           {},
           target="llvm",
       )
       for idx, (task, task_weight) in enumerate(zip(tasks, task_weights)):
            print(f"==== Task {idx}: {task.desc}. FLOPs {task.compute_dag.flop_ct}. (weight {task_weight} key: {task.workload_key}) =====")
           print(task.compute_dag)
       tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
       tuner.tune(
           auto_scheduler.TuningOptions(
               num_measure_trials=10,
               measure_callbacks=[],
           )
       )
   
   if __name__ == "__main__":
       main()
   ```
   
   CC: @Hzfengsy @tqchen @comaniac 
   
   ### Environment
   
   Based on the latest HEAD: https://github.com/apache/tvm/commit/24010db6c0e90bc555f6d12e23381fa7b00cf25d
   
   ### Steps to reproduce
   
   Run the script above.
   
   This issue is reported by @Kathryn-cat on VGG-16.
   

