dpankratz opened a new pull request #5156: [Bugfix] Fixed rewrite_simplify to correctly cast integers known at compile time.
URL: https://github.com/apache/incubator-tvm/pull/5156

## Bug

Casting values that are known at compile time produces incorrect results. For example, `1 * tir.Cast("bool", 77)` becomes `77` after the `rewrite_simplify` pass instead of `1`. However, if the same value is passed in at runtime, the generated code produces the expected result of `1`.

## Example

```
import numpy as np
import tvm
from tvm import te, tir

# Compile-time constant: the simplifier folds the cast incorrectly.
print(tvm.arith.Analyzer().rewrite_simplify(1 * tir.Cast('bool', 77)))  # prints 77

# Runtime value: the compiled function casts correctly.
a = te.var('a')
shape = (1,)
c = te.compute(shape, lambda i: 1 * tir.Cast('bool', a))
s = te.create_schedule([c.op])
f = tvm.build(s, [a, c])
c_tvm = tvm.nd.array(np.zeros(shape, dtype='int32'))
f(77, c_tvm)
assert c_tvm.asnumpy()[0] == 1
```

## Explanation

When a `Cast(dtype, value)` call is applied to a compile-time value, it is replaced with an `IntImm` carrying the `dtype` and `value` from the call. However, the `value` is stored as a C++ `int64_t`, so nothing limits it to the valid range of `dtype`. If the resulting `IntImm` is later used in an expression that triggers a type promotion, unexpected behaviour such as the example above occurs and the `Cast` effectively does nothing. In the example, the `bool` constant still holds `77`, so when it is promoted to `int32` for the multiplication, `77` surfaces unchanged.

## Fix

My fix ensures that when the `IntImm` is created, its value takes the type bounds into account. After the fix, the example above behaves as expected:

```
print(tvm.arith.Analyzer().rewrite_simplify(1 * tir.Cast('bool', 77)))  # prints 1
```

A review would be much appreciated! @tqchen @vinx13 @ZihengJiang
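To make the Explanation and Fix sections concrete, here is a minimal Python sketch of what "taking the type bounds into account" can mean for a folded constant. The helper `normalize_const` is hypothetical and is not TVM's implementation; it assumes C++-like conversion semantics, where converting to `bool` maps any nonzero value to `1` and converting to a fixed-width integer wraps modulo 2^bits (two's complement for signed types).

```python
def normalize_const(dtype: str, value: int) -> int:
    """Hypothetical helper: bring a compile-time integer into the range of `dtype`.

    Assumes C++-like conversions: bool maps nonzero to 1, and fixed-width
    integers wrap modulo 2**bits. This is an illustration, not TVM's code.
    """
    if dtype == "bool":
        return 1 if value != 0 else 0
    if dtype.startswith("uint"):
        bits = int(dtype[len("uint"):])
        return value % (1 << bits)
    if dtype.startswith("int"):
        bits = int(dtype[len("int"):])
        wrapped = value % (1 << bits)
        # Values in the upper half of the unsigned range are negative in two's complement.
        if wrapped >= 1 << (bits - 1):
            wrapped -= 1 << bits
        return wrapped
    raise ValueError(f"unsupported dtype: {dtype}")


# Without normalization, the bool constant keeps its raw payload of 77,
# so promoting it for `1 * ...` yields 77 (the bug described above).
print(1 * 77)                            # 77
# With normalization, the constant becomes 1 before any promotion.
print(1 * normalize_const("bool", 77))   # 1
print(normalize_const("int8", 200))      # -56
print(normalize_const("uint8", -1))      # 255
```

Normalizing at the moment the constant is created, rather than at each later use, means every subsequent rewrite sees a value that is already consistent with its `dtype`.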
