lygztq opened a new pull request #8555:
URL: https://github.com/apache/tvm/pull/8555


   ## Background and Motivation
   Currently, TVM uses `Any` to represent an unknown dimension when the input 
has a dynamic shape. When building the schedule for a Relay primitive function, 
each `AnyNode` is converted to a `Var` named `"any_dim"`. However, an 
element of a shape array cannot be negative, and using a `Var` to 
represent shape elements causes redundant boundary checks because the 
sign of a `Var` cannot be deduced. For example, consider a simple network with 
only a softmax operation:
   ```python
   import numpy as np
   import tvm
   from tvm import relay
   from time import time
   
   # actual input shape
   dim0 = 2
   dim1 = 10
   dim2 = 8
   dim3 = 16
   
   # relay var shapes
   v_dim0 = relay.Any()
   v_dim1 = relay.Any()
   v_dim2 = relay.Any()
   v_dim3 = relay.Any()
   
   # rt settings
   exec_mod = "vm"
   tgt = "cuda"
   dev = tvm.device(tgt)
   
   def get_mod():
       x = relay.var("x", shape=(v_dim0, v_dim1, v_dim2, v_dim3))
       y = relay.nn.softmax(x)
       mod = tvm.IRModule()
       mod["main"] = relay.Function([x], y)
       return mod
   
   mod = get_mod()
   ```
   The corresponding TIR looks like this (only part of the function is shown 
because it is too long):
   ```
   primfn(placeholder_1: handle, T_softmax_norm_1: handle) -> ()
     attr = {"global_symbol": "fused_nn_softmax", "tir.noalias": True}
     buffers = {T_softmax_norm: Buffer(T_softmax_norm_2: Pointer(float32), 
float32, [any_dim: int32, any_dim_1: int32, any_dim_2: int32, any_dim_3: 
int32], [stride: int32, stride_1: int32, stride_2: int32, stride_3: int32], 
type="auto"),
                placeholder: Buffer(placeholder_2: Pointer(float32), float32, 
[any_dim, any_dim_1, any_dim_2, any_dim_3], [stride_4: int32, stride_5: int32, 
stride_6: int32, stride_7: int32], type="auto")}
     buffer_map = {placeholder_1: placeholder, T_softmax_norm_1: 
T_softmax_norm} {
     attr [T_softmax_maxelem: Pointer(global float32)] "storage_scope" = 
"global";
     allocate(T_softmax_maxelem, float32, [((any_dim*any_dim_1)*any_dim_2)]);
     attr [T_softmax_exp: Pointer(global float32)] "storage_scope" = "global";
     allocate(T_softmax_exp, float32, 
[(((any_dim*any_dim_1)*any_dim_2)*any_dim_3)]) {
       attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", 
"blockIdx.x")] "thread_extent" = floordiv((((any_dim*any_dim_1)*any_dim_2) + 
511), 512);
       attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", 
"threadIdx.x")] "thread_extent" = 512;
       if (blockIdx.x < floordiv(((any_dim*any_dim_1)*any_dim_2), 512)) {
         if (floordiv(floordiv(((blockIdx.x*512) + threadIdx.x), any_dim_2), 
any_dim_1) < any_dim) {
           if (floormod(floordiv(((blockIdx.x*512) + threadIdx.x), any_dim_2), 
any_dim_1) < any_dim_1) {
             if (floormod(((blockIdx.x*512) + threadIdx.x), any_dim_2) < 
any_dim_2) {
               if (floordiv(((blockIdx.x*512) + threadIdx.x), any_dim_2) < 
(any_dim*any_dim_1)) {
                 T_softmax_maxelem[((blockIdx.x*512) + threadIdx.x)] = 
-3.40282e+38f32
               }
             }
           }
         }
   ...
   ```
   If you look carefully, you will find some unnecessary if-conditions such as 
`floormod(floordiv(((blockIdx.x*512) + threadIdx.x), any_dim_2), any_dim_1) < 
any_dim_1`. This is because `floormod` can have negative inputs. For example, 
`y < floormod(x, y) <= 0` if `y < 0`. Since the sign of `any_dim_1` is unknown 
and the output of `floormod` is greater than `any_dim_1` when 
`any_dim_1 < 0`, the condition cannot be proven true and a redundant 
if-condition is emitted here.
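
The sign dependence described above can be checked in plain Python, whose `%` operator also rounds the quotient toward negative infinity. This is only an illustrative sketch, not TVM code: the helper names `floormod` and `guard_holds` are made up for this example and stand in for the TIR intrinsic and the generated if-condition.

```python
def floormod(x, y):
    """Floor modulo. Python's % already follows floor-division semantics."""
    return x % y


def guard_holds(x, y):
    """Stand-in for the generated check `floormod(x, y) < y`."""
    return floormod(x, y) < y


xs = range(-20, 21)

# Positive divisor: the result lies in [0, y).
pos = [floormod(x, 4) for x in xs]
# Negative divisor: the result lies in (y, 0].
neg = [floormod(x, -4) for x in xs]

# With y > 0 the guard is always true, so it is redundant.
always_true_for_positive = all(
    guard_holds(x, y) for y in range(1, 9) for x in xs
)
# With y < 0 the guard is always false, so it cannot be dropped
# unless the sign of y is known.
always_false_for_negative = not any(
    guard_holds(x, y) for y in range(-8, 0) for x in xs
)

print(always_true_for_positive, always_false_for_negative)
```

Only when the divisor is known to be positive can the check be removed safely, and that knowledge is exactly what a plain `Var` does not carry.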
   
   ## Solution
   I think we currently have no need for negative elements in shape arrays. 
Therefore, using `SizeVar` (a variable known to be non-negative) instead of 
`Var` here is probably a better choice.
   

