PawelGlomski-Intel edited a comment on pull request #20753:
URL: https://github.com/apache/incubator-mxnet/pull/20753#issuecomment-979841318
> Hi, the purpose of `amp_multicast` is to cast the inputs of ops that take multiple of them (think `add`, for example) to the same, widest precision. We can't use multiple `amp_cast`s for this, since the cast is inserted before type inference, so we do not yet know what that widest type is. An example: take the `a + b` operation. If both `a` and `b` are `float16`, you do not want to insert any casts. If one of them (say `a`) is `float32` and the other (`b`) is `float16`, then you want to cast `b` to `float32`.
>
> Not sure what you mean by casting inputs to `amp_multicast` offline. There is an optimization for inference cases to cast parameters to `float16`, but those do not go into `amp_multicast` but to the regular `amp_cast`, I believe.

Thanks a lot, I thought that was the case with `amp_multicast` but wasn't 100% sure.

Regarding the inference optimization - it does in fact also apply to `amp_multicast`, and it is even tested. I removed this behavior because it seemed illogical to me, so now these tests fail. [Here](https://github.com/apache/incubator-mxnet/pull/15118#issuecomment-506105272) is a comment about adding it. Was that an incorrect approach, and is my current version correct?

Here is one of the tests (BTW, the dtype of a variable doesn't matter at all here - these parameters will always be cast to fp16):
https://github.com/apache/incubator-mxnet/blob/29ace886946941527047dc5deebe5b4b85b5e4cb/tests/python/gpu/test_amp.py#L133-L139
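
To make the quoted `amp_multicast` behavior concrete, here is a minimal sketch, assuming the MXNet 1.x symbol API and that the operator is exposed as `mx.sym.amp_multicast`, which checks the widest-type promotion through type inference:

```python
import mxnet as mx

# Two inputs with different dtypes; amp_multicast should promote both of its
# outputs to the widest input dtype (float32 here).
a = mx.sym.Variable('a')
b = mx.sym.Variable('b')
casted = mx.sym.amp_multicast(a, b, num_outputs=2)

# The widest dtype is only resolved during type inference, which is why a
# plain amp_cast (with a fixed target dtype) cannot be used here.
_, out_types, _ = casted.infer_type(a='float32', b='float16')
print(out_types)  # expected: both outputs inferred as float32
```

To illustrate the inference-time optimization discussed above (as I understand it, the one exposed through `cast_optional_params` added in #15118), here is a rough sketch of the kind of graph the linked test exercises; the exact symbols, helper names, and assertions in the real test may differ:

```python
import mxnet as mx
from mxnet.contrib import amp

# A parameter feeding a multi-input op; with cast_optional_params=True the
# converter is expected to cast such parameters to float16 offline, regardless
# of the dtype the variable was declared with (which is what the linked test
# checks and what my current change breaks).
data = mx.sym.Variable('data')
weight = mx.sym.Variable('weight', dtype='float32')
out = mx.sym.broadcast_add(data, weight)

converted = amp.convert_symbol(out,
                               target_dtype='float16',
                               cast_optional_params=True)
print(converted.tojson())  # inspect where amp_cast / amp_multicast nodes land
```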
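The open question is whether parameters flowing into `amp_multicast` should keep being cast offline like the `amp_cast` ones, or whether the multicast should be left to resolve them at type-inference time as in my current version.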
