SauryLi opened a new pull request #8618:
URL: https://github.com/apache/tvm/pull/8618


   When I use the Relay VM to compile and run a small TF while-loop network on 
the CUDA target, nvprof shows that four separate kernels 
(fused_less_less_logical_or_min_kernel0 through kernel3) are launched.
   
   Example
   
   ```python
   import tensorflow as tf  # TF 1.x API (tf.Session, tf.placeholder)

   with tf.Session() as sess:
       a = tf.placeholder(shape=(2, None), name="a", dtype=tf.float32)
       b = tf.placeholder(shape=(2, None), name="b", dtype=tf.float32)
       c = tf.placeholder(dtype=tf.float32, name="c")

       # Keep looping while either running sum is still below the threshold c.
       def cond(a, b, c):
           return tf.less(tf.reduce_sum(a), c) | tf.less(tf.reduce_sum(b), c)
   
       def body(a, b, c):
           a += 1
           b += 1
           return a, b, c
   
       new_a, new_b, _ = tf.while_loop(cond, body, [a, b, c])
       output = tf.math.add(new_a, new_b, name="output")
   ```
   
   Profiling result
   
   ```shell
               Type  Time(%)      Time     Calls       Avg       Min       Max  Name
    GPU activities:   34.22%  101.92us        50  2.0380us  1.9840us  2.4640us  fused_sum_kernel0
                      18.60%  55.392us        48  1.1540us  1.1200us  1.1840us  fused_add_kernel0
                       9.53%  28.383us        25  1.1350us  1.1200us  1.1840us  fused_less_less_logical_or_min_kernel0
                       9.53%  28.382us        25  1.1350us  1.1190us  1.2480us  fused_less_less_logical_or_min_kernel2
                       9.51%  28.319us        25  1.1320us  1.1190us  1.1520us  fused_less_less_logical_or_min_kernel1
                       9.41%  28.032us        25  1.1210us  1.0880us  1.1520us  fused_less_less_logical_or_min_kernel3
                       7.96%  23.712us        26     912ns     864ns  1.2160us  [CUDA memcpy DtoH]
                       0.86%  2.5600us         3     853ns     736ns  1.0560us  [CUDA memcpy HtoD]
                       0.39%  1.1520us         1  1.1520us  1.1520us  1.1520us  fused_add_1_kernel0
         API calls:   96.94%  87.460ms         2  43.730ms  38.606ms  48.855ms  cuModuleLoadData
                       1.34%  1.2082ms       128  9.4390us  2.8270us  717.65us  cudaMalloc
                       1.09%  987.12us       199  4.9600us  3.7170us  37.780us  cuLaunchKernel
                       0.42%  380.92us        29  13.135us  4.9080us  59.369us  cudaMemcpy
                       0.13%  117.01us       283     413ns     268ns  11.313us  cudaSetDevice
                       0.06%  57.765us       199     290ns     226ns  4.0520us  cudaGetDevice
                       0.01%  8.8010us         2  4.4000us  2.0600us  6.7410us  cudaStreamSynchronize
                       0.00%  3.3900us         7     484ns     262ns     982ns  cuModuleGetFunction
   ```
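
As a quick sanity check on the profile above, the sketch below sums the Calls column of the GPU kernel rows and compares it with the cuLaunchKernel count. The numbers are copied from the nvprof output; the figure of 25 launches for a hypothetical single fused kernel is an assumption (one launch per evaluation of the loop condition), not a measured result.

```python
# Calls column of the GPU kernel rows in the nvprof output (memcpy rows excluded).
kernel_calls = {
    "fused_sum_kernel0": 50,
    "fused_add_kernel0": 48,
    "fused_less_less_logical_or_min_kernel0": 25,
    "fused_less_less_logical_or_min_kernel1": 25,
    "fused_less_less_logical_or_min_kernel2": 25,
    "fused_less_less_logical_or_min_kernel3": 25,
    "fused_add_1_kernel0": 1,
}

# Total kernel launches; this should match the cuLaunchKernel row.
total_launches = sum(kernel_calls.values())

# Launches that belong to the four kernels generated for the loop condition.
cond_launches = sum(
    calls
    for name, calls in kernel_calls.items()
    if name.startswith("fused_less_less_logical_or_min")
)

# Assumption: a single fused kernel would run once per condition evaluation.
assumed_fused_launches = 25
launches_saved = cond_launches - assumed_fused_launches

print(total_launches, cond_launches, launches_saved)
```

Under that assumption, the four condition kernels account for roughly half of all kernel launches in this run, and merging them would remove three launches per loop iteration.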
   
   The corresponding TIR looks like this:
   
   ```
   primfn(placeholder_3: handle, placeholder_4: handle, placeholder_5: handle, T_identity_1: handle) -> ()
     attr = {"global_symbol": "fused_less_less_logical_or_min", "tir.noalias": True}
     buffers = {T_identity: Buffer(T_identity_2: Pointer(int8), bool, [], []),
                placeholder_2: Buffer(placeholder_6: Pointer(float32), float32, [], []),
                placeholder_1: Buffer(placeholder_7: Pointer(float32), float32, [], []),
                placeholder: Buffer(placeholder_8: Pointer(float32), float32, [], [])}
     buffer_map = {placeholder_3: placeholder, placeholder_4: placeholder_1, placeholder_5: placeholder_2, T_identity_1: T_identity} {
     attr [T_less: Pointer(int8)] "storage_scope" = "global";
     allocate(T_less, int8, [1]);
     attr [T_less_1: Pointer(int8)] "storage_scope" = "global";
     allocate(T_less_1, int8, [1]) {
       attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = 1;
       attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 1;
       T_less[0] = cast(int8, ((float32*)placeholder_8[0] < (float32*)placeholder_7[0]))
       attr [IterVar(blockIdx.x_1: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = 1;
       attr [IterVar(threadIdx.x_1: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 1;
       T_less_1[0] = cast(int8, ((float32*)placeholder_6[0] < (float32*)placeholder_7[0]))
       attr [IterVar(blockIdx.x_2: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = 1;
       attr [IterVar(threadIdx.x_2: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 1;
       T_less[0] = cast(int8, (cast(bool, (int8*)T_less[0]) || cast(bool, (int8*)T_less_1[0])))
       attr [IterVar(blockIdx.x_3: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = 1;
       attr [IterVar(threadIdx.x_3: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 1;
       T_identity_2[0] = cast(int8, cast(bool, (int8*)T_less[0]))
     }
   }
   ```
   
   To remove the redundant CUDA kernels, I changed the schedule implementation 
in cuda.reduction by adding a specific "EnableAutoInline" condition that 
auto-inlines the injective operations. I am not sure whether there is a better 
way to solve this problem, so suggestions and discussion are welcome.
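
The effect of auto-inlining described above can be illustrated with a toy model in plain Python. This is only a sketch of the reasoning, not TVM code and not the change made in this PR: the `Stage` class, `count_kernels`, and the stage name `T_logical_or` are hypothetical, and the `enable_auto_inline` flag merely borrows the "EnableAutoInline" name. The model assumes that every stage that is not inlined gets its own thread binding and therefore its own kernel.

```python
from dataclasses import dataclass


@dataclass
class Stage:
    """One compute stage of the fused op (toy model, not a TVM object)."""
    name: str
    injective: bool
    is_output: bool = False


def count_kernels(stages, enable_auto_inline):
    """Count kernels, assuming each non-inlined stage becomes one kernel."""
    kernels = 0
    for stage in stages:
        # An injective stage can be folded into its consumer,
        # but the final output stage must always be materialized.
        inlined = enable_auto_inline and stage.injective and not stage.is_output
        if not inlined:
            kernels += 1
    return kernels


# The four scalar stages visible in the TIR dump: two comparisons,
# their logical OR, and the final identity write to the output buffer.
stages = [
    Stage("T_less", injective=True),
    Stage("T_less_1", injective=True),
    Stage("T_logical_or", injective=True),  # hypothetical name for the OR stage
    Stage("T_identity", injective=True, is_output=True),
]

kernels_without_inline = count_kernels(stages, enable_auto_inline=False)
kernels_with_inline = count_kernels(stages, enable_auto_inline=True)
print(kernels_without_inline, kernels_with_inline)
```

In this model the uninlined schedule yields four kernels, matching the four thread-bound blocks in the TIR dump, while inlining the injective producers leaves only the output stage.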

