masahi opened a new pull request #7385:
URL: https://github.com/apache/tvm/pull/7385


   This is my proposed solution to add a while-loop-like feature to TIR in the simplest, least invasive way. It generalizes the termination condition of the For node from
   ```
   loop_var < extent
   ```
   to
   ```
   loop_var < extent && test
   ```
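
To make the semantics concrete, here is a plain-Python model of the generalized loop. It is not TVM code: `bounded_loop`, `test`, and `body` are hypothetical names, and it assumes that, as in a C-style `for`, the combined condition is checked before every iteration. The extent still bounds the trip count, while the test can end the loop early, which gives a while loop with a known maximum number of iterations.

```python
from typing import Callable


def bounded_loop(extent: int, test: Callable[[], bool], body: Callable[[int], None]) -> int:
    """Run body(i) while `i < extent and test()`; return the number of iterations run.

    Models the generalized For node: `extent` caps the trip count and
    `test` can terminate the loop early.
    """
    i = 0
    # Both conditions are checked before every iteration.
    while i < extent and test():
        body(i)
        i += 1
    return i


# Example: a loop allowed 10 iterations that stops once a counter reaches 3.
state = {"count": 0}


def step(i: int) -> None:
    state["count"] += 1


trips = bounded_loop(10, lambda: state["count"] < 3, step)  # trips == 3
```
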
   
   Using this, we can write binary search as follows (the complete test case, which implements NumPy's `searchsorted` function, is [here](https://github.com/masahi/tvm/blob/a1b2c4a57e136186165e54bb1148faa44ad0a899/tests/python/unittest/test_tir_ir_builder.py#L176)):
   ```
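   # lo and hi are single-element buffers holding the search interval [lo, hi)
   lo[0] = 0
   hi[0] = n
   v = Bptr[i]
   # Binary search needs at most floor(log2(n)) + 1 iterations
   num_loop = int(np.log2(n)) + 1
   
   # Runs at most num_loop times, but exits early once lo[0] >= hi[0]
   with ib.for_range(0, num_loop, test=(lo[0] < hi[0])) as _:
       mid = lo[0] + tvm.tir.floordiv(hi[0] - lo[0], 2).astype("int32")
       with ib.if_scope(Aptr[mid] < v):
           lo[0] = mid + 1
       with ib.else_scope():
           hi[0] = mid
   
   # lo[0] is the leftmost index at which v can be inserted into the sorted Aptr
   Cptr[i] = lo[0]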
   ```
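
As a sanity check on the iteration bound, here is a plain-Python version of the same loop. It is a sketch, not the code from the PR, and `searchsorted_left` is a hypothetical name. It swaps in `n.bit_length()`, which equals `int(log2(n)) + 1` for `n >= 1` without going through floating point, and is 0 for an empty array. Each iteration shrinks `hi - lo` to at most half its previous size (rounded down), so after `num_loop` iterations the interval is empty and `lo` is the lower bound; the `lo < hi` test lets the loop stop sooner when the interval empties early.

```python
from bisect import bisect_left
from typing import Sequence, Tuple


def searchsorted_left(a: Sequence[float], v: float) -> Tuple[int, int]:
    """Lower-bound binary search with a fixed maximum trip count.

    Mirrors the IR-builder loop: the extent is len(a).bit_length()
    and `lo < hi` is the early-exit test. Returns (index, iterations run).
    """
    lo, hi = 0, len(a)
    num_loop = len(a).bit_length()
    iters = 0
    # Same shape as for_range(0, num_loop, test=(lo < hi)).
    while iters < num_loop and lo < hi:
        mid = lo + (hi - lo) // 2
        if a[mid] < v:
            lo = mid + 1  # the answer lies strictly right of mid
        else:
            hi = mid  # a[mid] >= v, so mid is still a candidate
        iters += 1
    return lo, iters


# Cross-check against the standard library's lower-bound search.
sorted_vals = [1, 2, 2, 4, 7, 9]
for query in [0, 2, 3, 9, 10]:
    assert searchsorted_left(sorted_vals, query)[0] == bisect_left(sorted_vals, query)
```
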
   My motivation was to improve GPU NMS performance using a while loop, and it did:
   
   NMS workload from PyTorch MaskRCNN:
   ```
   Without while loop (current main): 4.11 ms
   With while loop (my branch): 3.66 ms
   ```
   And an extreme NMS workload from TF MaskRCNN, with 120,000 boxes and `max_out_size` = 100. The difference is large because the number of loop iterations goes from 120,000 to 100:
   ```
   Without while loop (current main): 51.31 ms
   With while loop (my branch): 17.63 ms
   ```
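
A rough plain-Python illustration of where the iteration savings come from, assuming the while loop lets the NMS outer loop over score-sorted boxes stop as soon as `max_out_size` boxes have been kept. This is a standard greedy NMS, not TVM's GPU kernel; `greedy_nms`, `iou`, `iou_threshold`, and the `(x1, y1, x2, y2)` box layout are illustrative assumptions.

```python
from typing import List, Sequence, Tuple

Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2), an assumed layout


def iou(a: Box, b: Box) -> float:
    """Intersection over union of two axis-aligned boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def greedy_nms(boxes: Sequence[Box], max_out_size: int,
               iou_threshold: float) -> Tuple[List[int], int]:
    """Greedy NMS over boxes already sorted by descending score.

    The outer loop is bounded by len(boxes) but also exits once
    max_out_size boxes are kept. Returns (kept indices, outer iterations).
    """
    suppressed = [False] * len(boxes)
    kept: List[int] = []
    i = 0
    # Extent check plus early-exit test, like the generalized For node.
    while i < len(boxes) and len(kept) < max_out_size:
        if not suppressed[i]:
            kept.append(i)
            # Suppress lower-scored boxes that overlap box i too much.
            for j in range(i + 1, len(boxes)):
                if not suppressed[j] and iou(boxes[i], boxes[j]) > iou_threshold:
                    suppressed[j] = True
        i += 1
    return kept, i


overlap = [(0.0, 0.0, 10.0, 10.0), (1.0, 1.0, 11.0, 11.0), (50.0, 50.0, 60.0, 60.0)]
result = greedy_nms(overlap, max_out_size=10, iou_threshold=0.5)  # ([0, 2], 3)
```

In this sketch the outer loop runs exactly `max_out_size` times when none of the leading boxes are suppressed, and somewhat more when some are; without the early exit it always visits every box.
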
   

