sunggg opened a new pull request, #14548: URL: https://github.com/apache/tvm/pull/14548
This PR adds dynamic strided slice, which will be the first data-dependent op in Unity.

### Overview

It consists of three parts (Op, TOPI, and Legalization), along with their test cases.

Data-dependent ops such as dynamic strided slice can be tricky when their output shape cannot be deduced automatically. In such cases we cannot lower them, because the TE infra requires a concrete output shape, defined at least in terms of symbolic variables. Manual shape function registration is therefore unavoidable for these operators, so that the compiler knows how to compute their output shapes. With this PR, users can register the shape func in TOPI and insert it via the match cast mechanism during legalization.

It's worth noting that for data-dependent ops, the current TOPI creates symbolic variables whenever the extent of a dimension is unknown, and a later mechanism handles their binding (see [link](https://github.com/apache/tvm/blob/unity/include/tvm/topi/transform.h#L672)). With this PR, the legalizer is instead in charge of creating the symbolic variables and binding them explicitly. It then passes the output shape to TOPI, so TOPI can simply use it to define its compute.

```Python
from tvm import tir, topi
from tvm.relax import BlockBuilder, Call, Expr, ExternFunc, ShapeStructInfo
from tvm.relax.transform.legalize_ops.common import register_legalize


@register_legalize("relax.dynamic_strided_slice")
def _dynamic_strided_slice(bb: BlockBuilder, call: Call) -> Expr:
    # 1. Insert shape function
    output_shape = bb.normalize(
        bb.call_te(
            topi.shape_func_dynamic_strided_slice,
            call.args[0],
            call.args[1],
            call.args[2],
            call.args[3],
        )
    )
    # 2. Convert tensor to shape and match cast with new symbolic vars
    # Get the output rank (length of the 1-D shape tensor)
    ndim = int(output_shape.struct_info.shape[0])
    output_shape = bb.emit(
        Call(
            ExternFunc("vm.builtin.tensor_to_shape"),
            [output_shape],
            sinfo_args=[ShapeStructInfo(ndim=ndim)],
        )
    )
    output_shape_vars = [tir.Var("s", "int64") for _ in range(ndim)]
    bb.match_cast(output_shape, ShapeStructInfo(output_shape_vars))

    # 3. Pass the output shape vars to TOPI
    return bb.call_te(
        topi.dynamic_strided_slice,
        call.args[0],
        call.args[1],
        call.args[2],
        call.args[3],
        output_shape=output_shape_vars,
    )
```

In an internal discussion with @junrushao, we confirmed that this should comply with the WIP PR https://github.com/apache/tvm/pull/14278. Also, since this requires changing the existing `topi::dynamic_strided_slice` (see [link](https://github.com/apache/tvm/blob/unity/include/tvm/topi/transform.h#L709)), this PR creates a `relax` namespace and implements its own version there.

### Notes

* Relax already has the `relax.strided_slice` op, which covers non-data-dependent scenarios. It remains useful because it avoids a current limitation of `relax.dynamic_strided_slice`: the latter's shape analysis may not be informative enough for downstream users such as `tensor_to_shape`, which requires an integer shape known at compile time. We will revisit unifying the two ops once we have a better understanding.
* This PR adapts Relay's `topi::dynamic_strided_slice` to Relax standards without fixing its current limitation. It therefore expects the `begin`/`end`/`strides` tensors to be preprocessed so that their length equals the rank of the `data` tensor. PR https://github.com/apache/tvm/pull/14493 would be helpful for such preprocessing.

### Discussion

* I named it "dynamic" strided slice only for simplicity and to follow tradition. Strictly speaking, "data-dependent" strided slice would be more accurate, since `relax.strided_slice` already covers shape dynamism in the `data` tensor. Any suggestions would be very helpful.
* To maintain the Relax version of TOPI, I created a `relax` namespace for now. Is there a better way?

cc. @yongwww @jwfromm @psrivas2 @slyubomirsky @junrushao @tqchen @MasterJH5574
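For intuition about what the shape function in step 1 of the legalizer has to compute, here is a minimal pure-Python sketch of a strided-slice output-shape rule. It assumes Python/NumPy slicing semantics (negative indices wrap around, out-of-range indices are clamped) and that `begin`/`end`/`strides` already have one entry per axis. `strided_slice_out_shape` is a hypothetical name; this is not the TOPI shape function added by this PR.

```python
def strided_slice_out_shape(data_shape, begin, end, strides):
    """Output shape of a strided slice whose begin/end/strides are runtime values.

    Assumes Python/NumPy slicing semantics and one begin/end/stride entry per axis.
    """
    if not (len(data_shape) == len(begin) == len(end) == len(strides)):
        raise ValueError("begin/end/strides must match the rank of data")
    out_shape = []
    for dim, b, e, s in zip(data_shape, begin, end, strides):
        if s == 0:
            raise ValueError("stride must be non-zero")
        # Wrap negative indices around the axis length.
        if b < 0:
            b += dim
        if e < 0:
            e += dim
        if s > 0:
            # Forward slice: clamp into [0, dim], then count ceil((e - b) / s).
            b = min(max(b, 0), dim)
            e = min(max(e, 0), dim)
            out_shape.append(max(0, (e - b + s - 1) // s))
        else:
            # Backward slice: clamp into [-1, dim - 1], then count ceil((b - e) / |s|).
            b = min(max(b, -1), dim - 1)
            e = min(max(e, -1), dim - 1)
            out_shape.append(max(0, (b - e - s - 1) // -s))
    return out_shape
```

For example, `strided_slice_out_shape([4, 10, 3], [0, 1, -1], [4, 9, -4], [2, 3, -1])` returns `[2, 3, 3]`, matching the shape of `x[0:4:2, 1:9:3, -1:-4:-1]` for an array `x` of shape `(4, 10, 3)`. Because `begin`/`end`/`strides` are tensor values, a rule like this can only be evaluated at runtime, which is why its result has to re-enter the IR through `tensor_to_shape` and a match cast.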
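Once the legalizer has bound the output shape to fresh symbolic variables, the compute itself only needs index arithmetic. The NumPy sketch below illustrates that split. It assumes `begin` has already been wrapped into the valid range and that `out_shape` was produced by a separate shape function. `strided_slice_compute` is a hypothetical name, and this is not the TOPI compute.

```python
import numpy as np


def strided_slice_compute(data, begin, strides, out_shape):
    """Gather data[begin + i * stride] for every output index i, axis by axis.

    Assumes `begin` is already normalized (negative indices wrapped) and that
    `out_shape` comes from a separate shape function.
    """
    data = np.asarray(data)
    # One index vector per axis: begin, begin + stride, begin + 2 * stride, ...
    axis_indices = [b + s * np.arange(n) for b, s, n in zip(begin, strides, out_shape)]
    # np.ix_ builds an open mesh, so the per-axis vectors combine into a full grid.
    return data[np.ix_(*axis_indices)]
```

Each output index `i` on an axis reads input index `begin + i * stride`, so `end` is not needed once `out_shape` is known. All of the data-dependent information is carried by the shape.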
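The second note requires `begin`/`end`/`strides` to be preprocessed to the rank of `data`. Below is a minimal sketch of one such normalization, using plain Python lists. `pad_slice_params` is a hypothetical name. Taking uncovered trailing axes whole (begin=0, end=dim, stride=1) is an assumption about the desired semantics, not necessarily what PR https://github.com/apache/tvm/pull/14493 implements.

```python
def pad_slice_params(data_shape, begin, end, strides=None):
    """Pad begin/end/strides so each has exactly one entry per axis of `data`.

    Hypothetical helper: axes beyond the given prefix are taken whole
    (begin=0, end=dim, stride=1). If strides is omitted, it defaults to 1.
    """
    ndim = len(data_shape)
    begin, end = list(begin), list(end)
    strides = [1] * len(begin) if strides is None else list(strides)
    if not (len(begin) == len(end) == len(strides)):
        raise ValueError("begin, end and strides must have the same length")
    if len(begin) > ndim:
        raise ValueError("more slice parameters than data axes")
    k = len(begin)
    # Uncovered trailing axes keep their full extent.
    return (
        begin + [0] * (ndim - k),
        end + list(data_shape[k:]),
        strides + [1] * (ndim - k),
    )
```

For example, `pad_slice_params([4, 10, 3], [1], [3])` gives `([1, 0, 0], [3, 10, 3], [1, 1, 1])`, which describes the same slice as `x[1:3]`.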
