zheng-da commented on a change in pull request #8302: Refactor operators & MKLDNN
URL: https://github.com/apache/incubator-mxnet/pull/8302#discussion_r162523057
##########
File path: src/operator/tensor/cast_storage.cc
##########
@@ -25,10 +25,50 @@
#include "./cast_storage-inl.h"
#include "../elemwise_op_common.h"
#include "../tensor/elemwise_unary_op.h"
+#include "../nn/mkldnn/mkldnn_base-inl.h"
namespace mxnet {
namespace op {
+#if MXNET_USE_MKLDNN == 1
+
+// Returns the size in bytes of a single element of the mshadow data type
+// identified by `dtype`. The value -1 is returned only if MSHADOW_TYPE_SWITCH
+// falls through without matching `dtype` (NOTE(review): the macro presumably
+// aborts on unknown types itself - confirm); callers should treat a
+// non-positive result as an error.
+static inline int get_type_size(int dtype) {
+  int elem_size = -1;
+  MSHADOW_TYPE_SWITCH(dtype, DType, { elem_size = sizeof(DType); });
+  return elem_size;
+}
+
+// Writes the contents of `src` into the blob of `dst_arr` in the default
+// (non-MKLDNN-specific) layout.
+//
+// ctx     - operator context; execution must be on CPU.
+// src     - source array; its shape must equal the shape of `dst_arr`. It may
+//           be held in an MKLDNN-specific memory format.
+// dst_arr - destination array; its data blob is overwritten.
+//
+// If the data types differ, the copy goes through ndarray::Copy, which also
+// performs the type cast. Otherwise the source is either reordered by MKLDNN
+// into the default format or, if it is already in that format, copied raw.
+void CastStorageMKLDnsImpl(const OpContext& ctx, const NDArray& src,
+                           const NDArray &dst_arr) {
+  TBlob dns = dst_arr.data();
+  CHECK_EQ(ctx.run_ctx.ctx.dev_mask(), Context::kCPU);
+  CHECK(src.shape() == dns.shape_);
+  if (src.dtype() != dns.type_flag_) {
+    // If the input and output have different data types, we have to convert
+    // the source array into the default layout, cast the data type and copy
+    // data to the destination array.
+    CHECK(src.ctx() == dst_arr.ctx());
+    ndarray::Copy<cpu, cpu>(src.data(), &dns, src.ctx(), dst_arr.ctx(),
+                            ctx.run_ctx);
+    return;
+  }
+  // Same data type: convert the source data to the default format and write
+  // it to the destination directly.
+  auto src_mkldnn = src.GetMKLDNNData();
+  auto src_pd = src_mkldnn->get_primitive_desc();
+  auto def_format = GetDefaultFormat(src_pd.desc());
+  if (def_format != src_pd.desc().data.format) {
+    // The source uses a non-default MKLDNN format: reorder it into an MKLDNN
+    // memory object that wraps the destination buffer.
+    auto dst_pd = GetPrimitiveDesc(src_pd, def_format);
+    mkldnn::memory dst_mkldnn(dst_pd, dns.dptr_);
+    std::vector<mkldnn::primitive> net;
+    net.push_back(mkldnn::reorder(*src_mkldnn, dst_mkldnn));
+    mkldnn::stream(mkldnn::stream::kind::eager).submit(net).wait();
+  } else {
+    // The source is already in the default format, so a raw byte copy is
+    // enough. Guard the element size: a negative value would wrap to a huge
+    // size_t in the memcpy length.
+    const TBlob &src_blob = src.data();
+    const int type_size = get_type_size(dns.type_flag_);
+    CHECK_GT(type_size, 0) << "Unsupported data type " << dns.type_flag_;
+    memcpy(dns.dptr_, src_blob.dptr_,
+           src.shape().Size() * static_cast<size_t>(type_size));
+  }
+}
+
+#endif
+
Review comment:
Good point. I guess we should use kFComputeEx for all storage types?
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services