wzh99 opened a new pull request, #13280: URL: https://github.com/apache/tvm/pull/13280
This PR adds a rank check on the input tensor to the type inference of `nn.instance_norm`, for two reasons.

First, according to the definition of [Instance Normalization](https://paperswithcode.com/method/instance-normalization), the operator normalizes only the data dimensions. The input tensor must therefore have rank at least 3; otherwise the operator does not produce meaningful results.

Second, `nn.instance_norm` on a tensor of rank less than 3 causes a problem in the `SimplifyInference` optimization pass. Before converting the operator to its lower-level computation definition, the pass collects the axes to reduce:

https://github.com/apache/tvm/blob/f15afd225140e2a501b8b6aa2def0fd94d31bc54/src/relay/transforms/simplify_inference.cc#L149-L152

If the rank of the input tensor is less than 3, `reduced_axes` is empty. According to `GetRealAxis`, an empty axis list means that all dimensions are reduced:

https://github.com/apache/tvm/blob/f15afd225140e2a501b8b6aa2def0fd94d31bc54/include/tvm/topi/reduction.h#L67-L70

This is problematic because the batch and `axis` dimensions should not be reduced.

cc @masahi
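To make the failure mode concrete, here is a minimal pure-Python sketch of the two steps described above. It does not call TVM; the function names `instance_norm_reduced_axes` and `get_real_axis` are hypothetical stand-ins that only model the behaviour the PR describes (skip the batch axis and `axis`, then treat an empty axis list as "reduce everything"), and the default `axis=1` is an assumption made for illustration.

```python
def instance_norm_reduced_axes(ndim, axis=1):
    """Model of the axis collection in SimplifyInference (hypothetical helper):
    every axis except the batch axis (0) and `axis` is reduced."""
    if axis < 0:
        axis += ndim
    return [i for i in range(1, ndim) if i != axis]


def get_real_axis(ndim, axes):
    """Model of the empty-list rule attributed to GetRealAxis (hypothetical helper):
    an empty axis list is interpreted as a reduction over all dimensions."""
    if not axes:
        return list(range(ndim))
    return sorted({a + ndim if a < 0 else a for a in axes})


for ndim in (2, 3, 4):
    requested = instance_norm_reduced_axes(ndim)
    effective = get_real_axis(ndim, requested)
    print(f"rank {ndim}: requested {requested}, effectively reduced {effective}")
```

Under these assumptions, a rank-2 input requests no axes and therefore ends up reducing axes 0 and 1, that is, the batch and `axis` dimensions, while rank-3 and rank-4 inputs reduce only the trailing data dimensions. Rejecting rank below 3 during type inference avoids reaching that case at all.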
