ptrendx commented on pull request #20753: URL: https://github.com/apache/incubator-mxnet/pull/20753#issuecomment-979413104
Hi, the purpose of `amp_multicast` is to cast the inputs of ops that take several inputs (think `add`, for example) to the same precision, namely the widest one among them. We can't use multiple `amp_cast` ops for this, since the cast is inserted before type inference runs, so at that point we do not know what the widest type is. As an example, take the operation `a + b`. If both `a` and `b` are `float16`, you do not want to insert any casts. If one of them (say `a`) is `float32` and the other (`b`) is `float16`, then you want to cast `b` to `float32`. I'm not sure what you mean by casting the inputs to `amp_multicast` offline. There is an optimization for inference that casts parameters to `float16`, but I believe those go through the regular `amp_cast`, not `amp_multicast`.
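To make the `a + b` example concrete, here is a rough sketch of the "cast everything to the widest input type" rule in plain Python. It is only a toy model of the decision, not the MXNet implementation: `WIDTH` and `widest_cast_plan` are hypothetical names made up for illustration, and dtypes are represented as strings.

```python
# Toy model of the "widest type" rule; not MXNet code.
# WIDTH and widest_cast_plan are hypothetical names used for illustration.
WIDTH = {"float16": 16, "float32": 32, "float64": 64}


def widest_cast_plan(input_dtypes):
    """Return (target dtype, indices of the inputs that would need a cast)."""
    # The target is the widest dtype among the inputs. It can only be
    # computed once every input dtype is known, i.e. after type inference.
    target = max(input_dtypes, key=lambda d: WIDTH[d])
    needs_cast = [i for i, d in enumerate(input_dtypes) if d != target]
    return target, needs_cast


# Both inputs are float16: nothing needs to be cast.
print(widest_cast_plan(["float16", "float16"]))  # ('float16', [])

# a is float32, b is float16: only b (index 1) is cast, to float32.
print(widest_cast_plan(["float32", "float16"]))  # ('float32', [1])
```

The point of the sketch is that the target type depends on all inputs at once, which is why a single multi-input op is used instead of several independent `amp_cast` ops with a fixed target type.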
