anko-intel commented on a change in pull request #19749:
URL: https://github.com/apache/incubator-mxnet/pull/19749#discussion_r559734406



##########
File path: src/operator/tensor/amp_cast.cc
##########
@@ -41,25 +41,29 @@ static void AMPCastExCPU(const nnvm::NodeAttrs& attrs,
   if (req[0] == kWriteInplace) {
     return;
   }
-  mkldnn::engine cpu_engine = mxnet::CpuEngine::Get()->get_engine();
   auto data = inputs[0];
-  if (data.IsView() && data.IsMKLDNNData())
-    data = data.Reorder2Default();
-  const auto i_mem = data.GetMKLDNNData();
-  const size_t i_ndim = data.shape().ndim();
-  mkldnn::memory::dims i_dims = mkldnn::memory::dims(i_ndim);
-  for (size_t i = 0; i < i_ndim; i++) {
-    i_dims[i] = static_cast<int>(data.shape()[i]);
+  if (data.dtype() != mshadow::kFloat16) {

Review comment:
       I considered that. A newly created isValidMKLDNNDataType() function would
make sense if it could be reused in many places, such as MKLDNNStorageType()
for FInferStorageType. In this particular case, however, the amp_cast operator
accepts only three float types (see
https://github.com/apache/incubator-mxnet/blob/v1.x/src/operator/tensor/amp_cast.h#L70
), so I simply excluded float16, which MKLDNN does not support.
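A minimal, self-contained C++ sketch of that reasoning follows. DType, IsValidMKLDNNDataType(), IsAMPCastInputType() and UseMKLDNNPath() are hypothetical stand-ins, not mshadow's type flags or existing MXNet functions. The sketch also assumes the three accepted types are float32, float16 and bfloat16, and that MKLDNN handles float32, bfloat16 and int8; the linked amp_cast.h is the authority on the actual list. It shows that, restricted to amp_cast's inputs, a plain `dtype != float16` test gives the same answer as a general validity check.

```cpp
// Hypothetical sketch: DType stands in for mshadow's type flags.
enum class DType { kFloat32, kFloat16, kBfloat16, kFloat64, kInt8 };

// General helper suggested in review (hypothetical): the types this sketch
// ASSUMES MKLDNN can handle.
constexpr bool IsValidMKLDNNDataType(DType t) {
  return t == DType::kFloat32 || t == DType::kBfloat16 || t == DType::kInt8;
}

// Input types amp_cast is ASSUMED to accept (three float types).
constexpr bool IsAMPCastInputType(DType t) {
  return t == DType::kFloat32 || t == DType::kFloat16 || t == DType::kBfloat16;
}

// Narrow check in the spirit of the diff: only float16 is excluded.
constexpr bool UseMKLDNNPath(DType t) { return t != DType::kFloat16; }

// True when, for an amp_cast input, the narrow check agrees with the general
// helper; vacuously true for types amp_cast never receives.
constexpr bool ChecksAgree(DType t) {
  return !IsAMPCastInputType(t) || UseMKLDNNPath(t) == IsValidMKLDNNDataType(t);
}
```

Outside amp_cast's domain the two checks diverge (a float64 input passes `UseMKLDNNPath()` but fails `IsValidMKLDNNDataType()`), which is why a reusable helper only pays off when several call sites need the broader test.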




----------------------------------------------------------------
This is an automated message from the Apache Git Service.