anko-intel commented on a change in pull request #19749:
URL: https://github.com/apache/incubator-mxnet/pull/19749#discussion_r559734406
##########
File path: src/operator/tensor/amp_cast.cc
##########
@@ -41,25 +41,29 @@ static void AMPCastExCPU(const nnvm::NodeAttrs& attrs,
if (req[0] == kWriteInplace) {
return;
}
- mkldnn::engine cpu_engine = mxnet::CpuEngine::Get()->get_engine();
auto data = inputs[0];
- if (data.IsView() && data.IsMKLDNNData())
- data = data.Reorder2Default();
- const auto i_mem = data.GetMKLDNNData();
- const size_t i_ndim = data.shape().ndim();
- mkldnn::memory::dims i_dims = mkldnn::memory::dims(i_ndim);
- for (size_t i = 0; i < i_ndim; i++) {
- i_dims[i] = static_cast<int>(data.shape()[i]);
+ if (data.dtype() != mshadow::kFloat16) {
Review comment:
I considered that. If a newly created isValidMKLDNNDataType() function could be
reused in many places, the way MKLDNNStorageType() is for FInferStorageType, it
would make sense. But in this particular situation the amp_cast operator accepts
only 3 float types (see
https://github.com/apache/incubator-mxnet/blob/v1.x/src/operator/tensor/amp_cast.h#L70
), so I simply excluded float16 as not supported by MKLDNN.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]