PawelGlomski-Intel commented on pull request #20753: URL: https://github.com/apache/incubator-mxnet/pull/20753#issuecomment-992786074
@szha What is your take on this? When a model parameter is an input of an `amp_multicast` node, it is cast to lp16 (with `cast_optional_params` set to `True`). An `amp_multicast` node is added for the ops in [this list](https://github.com/apache/incubator-mxnet/blob/40359ceda150ca75da6e45b1ea35d747ef53deac/python/mxnet/amp/lists/symbol_fp16.py#L641), so that all of their inputs share one dtype (the most accurate one). By definition, such ops should run in lp16 only when all of their inputs are already in low precision. Currently, however, input parameters are always cast to lp16, even when all the inputs are f32. I don't think that is intuitive.
