Lunderberg commented on issue #17254: URL: https://github.com/apache/tvm/issues/17254#issuecomment-2276664434
Looks like this is both a bug in the `ToMixedPrecision` pass and a limitation in the well-formed checker.

The bug is that the current implementation of `ToMixedPrecision` implicitly assumes that all tensors are float32 before the pass is applied. As a result, when the annotation `MixedPrecisionPolicyKind::kNever` (that is, never change the dtype of this operator) is encountered for `R.call_tir`, it casts the inputs/outputs of `R.call_tir` to float32.

The limitation in the well-formed checker is that the callee in `R.call_tir` is not inspected, so the issue is not identified when it first occurs and is only caught later by `FuseTIR`. The `StructuralEqual()` check looks like it was intended to catch a mismatched TIR buffer shape, but serendipitously caught the mismatched dtype between the output of `fused_layer_norm_cast` and the input of `conv2d_cast_relu`.

To fix this, I think the `ToMixedPrecision` pass will need updated behavior for `kNever`. Currently, an operator with a mixed-precision policy of `kNever` has all of its inputs cast to `float32`; instead, the inputs should be cast to the same dtype they had prior to `ToMixedPrecision`.
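To illustrate the proposed change, here is a minimal, self-contained Python sketch. It does not use TVM's actual APIs; the `Tensor` class, `cast` helper, and the two `apply_knever_*` functions are all hypothetical stand-ins that model only the dtype bookkeeping: the buggy behavior hardcodes a cast to float32, while the fixed behavior restores each input's dtype as recorded before the pass ran.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Tensor:
    """Hypothetical stand-in for a Relax tensor, tracking only name and dtype."""
    name: str
    dtype: str

def cast(t: Tensor, dtype: str) -> Tensor:
    """Model of a dtype cast; returns a tensor with the target dtype."""
    return Tensor(t.name, dtype) if t.dtype != dtype else t

def apply_knever_buggy(inputs):
    # Current behavior: assumes every tensor was float32 before the pass,
    # so kNever "restores" inputs by casting them to float32.
    return [cast(t, "float32") for t in inputs]

def apply_knever_fixed(inputs, original_dtypes):
    # Proposed behavior: cast each input back to its recorded pre-pass dtype,
    # leaving a kNever operator's signature genuinely unchanged.
    return [cast(t, original_dtypes[t.name]) for t in inputs]

# A call_tir argument that was already float16 before ToMixedPrecision ran:
original_dtypes = {"x": "float16"}
mixed = [Tensor("x", "float16")]

print([t.dtype for t in apply_knever_buggy(mixed)])                   # ['float32']
print([t.dtype for t in apply_knever_fixed(mixed, original_dtypes)])  # ['float16']
```

With the buggy behavior, the float16 input is silently rewritten to float32, producing exactly the dtype mismatch between the `call_tir` site and its callee that `FuseTIR` later trips over; the fixed behavior leaves it untouched.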
