ptrendx commented on pull request #20753:
URL: https://github.com/apache/incubator-mxnet/pull/20753#issuecomment-979413104


   Hi, the purpose of `amp_multicast` is to cast the inputs of ops that take
multiple inputs (think `add`, for example) to a common type, namely the widest
precision among them. We can't use multiple `amp_cast` nodes for this, because
the casts are inserted before type inference runs, so at that point we do not
know what the widest type is. As an example, take the operation `a + b`. If
both `a` and `b` are `float16`, you do not want to insert any casts. If one of
them (say `a`) is `float32` and the other (`b`) is `float16`, then you want to
cast `b` to `float32`.
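
   To make the `a + b` case concrete, here is a minimal sketch of the
"cast everything to the widest input type" rule. It is not MXNet code: dtypes
are modelled as plain strings, and `WIDTH_BITS`, `widest_dtype` and
`multicast_plan` are hypothetical names used only for illustration.

```python
# Hypothetical sketch of the widest-type rule; not the MXNet implementation.
# Dtypes are plain strings here, mapped to their width in bits.
WIDTH_BITS = {"float16": 16, "float32": 32}


def widest_dtype(dtypes):
    """Return the widest dtype among the given input dtypes."""
    return max(dtypes, key=lambda d: WIDTH_BITS[d])


def multicast_plan(input_dtypes):
    """For each input, return the dtype to cast it to, or None if no cast is needed."""
    target = widest_dtype(input_dtypes)
    return [None if d == target else target for d in input_dtypes]


# Both inputs are float16: no casts are inserted.
print(multicast_plan(["float16", "float16"]))  # [None, None]
# a is float32, b is float16: only b is cast, to float32.
print(multicast_plan(["float32", "float16"]))  # [None, 'float32']
```

   The point of the sketch is that the target type depends on all inputs
together, which is why a single multi-input node can decide it once the types
are known, while independent per-input casts fixed in advance cannot.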
   
   I'm not sure what you mean by casting inputs to `amp_multicast` offline.
There is an optimization for inference that casts parameters to `float16`, but
I believe those go through the regular `amp_cast`, not `amp_multicast`.


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.