masahi opened a new pull request #7385: URL: https://github.com/apache/tvm/pull/7385
This is my proposed solution for adding a while-loop-like feature to TIR in the simplest, least invasive way. It generalizes the termination condition of the `For` node from

```
loop_var < extent
```

to

```
loop_var < extent && test
```

Using this, we can write binary search as follows (see the complete test case, which implements the numpy `searchsorted` function, [here](https://github.com/masahi/tvm/blob/a1b2c4a57e136186165e54bb1148faa44ad0a899/tests/python/unittest/test_tir_ir_builder.py#L176)).

```python
# Aptr: sorted input, Bptr: values to search for, Cptr: output indices
lo[0] = 0
hi[0] = n
v = Bptr[i]
# Binary search over n elements needs at most floor(log2(n)) + 1 iterations
num_loop = int(np.log2(n)) + 1
# The loop exits early as soon as lo[0] < hi[0] becomes false
with ib.for_range(0, num_loop, test=(lo[0] < hi[0])) as _:
    mid = lo[0] + tvm.tir.floordiv(hi[0] - lo[0], 2).astype("int32")
    with ib.if_scope(Aptr[mid] < v):
        lo[0] = mid + 1
    with ib.else_scope():
        hi[0] = mid
Cptr[i] = lo[0]
```

My motivation was to improve GPU NMS performance using a while loop, and it did. NMS workload from PyTorch MaskRCNN:

```
Without while loop (current main): 4.11 ms
With while loop (my branch):       3.66 ms
```

And an extreme NMS workload from TF MaskRCNN, with 120000 boxes and max_out_size = 100. The difference is huge because the number of loop iterations drops from 120000 to 100:

```
Without while loop (current main): 51.31 ms
With while loop (my branch):       17.63 ms
```
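To make the generalized condition `loop_var < extent && test` concrete without TVM, here is a minimal pure-Python model. It assumes `test` is re-evaluated before every iteration, so state written in the loop body can end the loop early. `for_range_with_test` and `searchsorted_left` are hypothetical names for illustration only, not TVM APIs; the binary search mirrors the `lo[0]`/`hi[0]` snippet in the PR and is cross-checked against Python's `bisect.bisect_left`, which matches numpy's `searchsorted` with its default `side="left"`.

```python
import bisect
from typing import Callable, Iterator, List


def for_range_with_test(extent: int, test: Callable[[], bool]) -> Iterator[int]:
    """Yield loop_var = 0, 1, ... while `loop_var < extent and test()` holds.

    Hypothetical model of the proposed For termination condition (not TVM's
    ir_builder API). `test` is called before each iteration.
    """
    loop_var = 0
    while loop_var < extent and test():
        yield loop_var
        loop_var += 1


def searchsorted_left(a: List[int], v: int) -> int:
    """Return the leftmost index at which v could be inserted into sorted `a`."""
    n = len(a)
    lo, hi = [0], [n]  # one-element lists mimic the lo[0] / hi[0] buffers
    num_loop = n.bit_length()  # equals int(log2(n)) + 1 for n >= 1, and 0 for n == 0
    for _ in for_range_with_test(num_loop, lambda: lo[0] < hi[0]):
        mid = lo[0] + (hi[0] - lo[0]) // 2
        if a[mid] < v:
            lo[0] = mid + 1
        else:
            hi[0] = mid
    return lo[0]


if __name__ == "__main__":
    # Exhaustive cross-check against the standard library on small inputs.
    for n in range(12):
        a = [k // 2 for k in range(n)]  # sorted, with duplicates
        for v in range(-1, 8):
            assert searchsorted_left(a, v) == bisect.bisect_left(a, v)
    print(searchsorted_left([1, 3, 5, 7], 5))  # prints 2
```

Because `num_loop` only bounds the search, the `test` callback is what lets the loop stop as soon as the interval is empty, which is the same early-exit behavior the PR relies on for NMS.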
