Lunderberg commented on issue #17254:
URL: https://github.com/apache/tvm/issues/17254#issuecomment-2276664434

   Looks like this is both a bug in the `ToMixedPrecision` pass and a limitation in the well-formed checker.  The bug is that the current implementation of `ToMixedPrecision` implicitly assumes that all tensors are float32 before the pass is applied.  As a result, when the annotation `MixedPrecisionPolicyKind::kNever` (that is, never change the dtype of this operator) is encountered for `R.call_tir`, it casts the inputs/outputs of `R.call_tir` to float32, regardless of their original dtype.
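
   As a minimal plain-Python sketch of the faulty decision (the function name and the string encoding of the policy are hypothetical for illustration; the real logic lives inside the `ToMixedPrecision` implementation):

```python
def cast_target_dtype_buggy(policy: str, mixed_dtype: str) -> str:
    """Hypothetical model of the current cast decision in ToMixedPrecision.

    `kNever` unconditionally falls back to float32, implicitly assuming
    that every tensor was float32 before the pass ran.
    """
    if policy == "kNever":
        return "float32"
    return mixed_dtype


# A float16 argument to R.call_tir gets cast to float32, even though the
# TIR callee still expects float16, producing the later FuseTIR mismatch.
print(cast_target_dtype_buggy("kNever", "float16"))   # float32
print(cast_target_dtype_buggy("kAlways", "float16"))  # float16
```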
   
   The limitation in the well-formed checker is that the callee of `R.call_tir` is not inspected, so the issue is not identified when it is first introduced, and is only caught later by `FuseTIR`.  The `StructuralEqual()` check looks like it was intended to catch a mismatched TIR buffer shape, but serendipitously caught the mismatched dtype between the output of `fused_layer_norm_cast` and the input of `conv2d_cast_relu`.
   
   To fix this, I think the `ToMixedPrecision` pass will need updated behavior for `kNever`.  Currently, an operator with a mixed-precision policy of `kNever` has all of its inputs cast to `float32`; instead, the inputs should be cast to the same dtype they had prior to `ToMixedPrecision`.
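
   Continuing the sketch above, the proposed behavior would thread the pre-pass dtype through to the cast decision (again, names are hypothetical; the actual change would be in the pass's C++ implementation):

```python
def cast_target_dtype_fixed(policy: str, mixed_dtype: str,
                            original_dtype: str) -> str:
    """Hypothetical model of the proposed cast decision.

    `kNever` restores the dtype the tensor had before ToMixedPrecision
    ran, rather than hardcoding float32.
    """
    if policy == "kNever":
        return original_dtype
    return mixed_dtype


# A tensor that was float16 before the pass stays float16 at a kNever
# operator, so the R.call_tir callee signature still matches.
print(cast_target_dtype_fixed("kNever", "float16", "float16"))  # float16
# Tensors that really were float32 keep the old behavior.
print(cast_target_dtype_fixed("kNever", "float16", "float32"))  # float32
```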


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
