zheng-da commented on a change in pull request #8302: Refactor operators & MKLDNN
URL: https://github.com/apache/incubator-mxnet/pull/8302#discussion_r164905297
 
 

 ##########
 File path: src/operator/nn/fully_connected.cc
 ##########
 @@ -23,58 +23,165 @@
 * \brief fully connected operator
 */
 #include "./fully_connected-inl.h"
+#include "./mkldnn/mkldnn_ops-inl.h"
+#include "./mkldnn/mkldnn_base-inl.h"
 #if MXNET_USE_NNPACK == 1
 #include "./nnpack/nnpack_fully_connected-inl.h"
 #endif  // MXNET_USE_NNPACK
 
 namespace mxnet {
 namespace op {
-template<>
-Operator* CreateOp<cpu>(FullyConnectedParam param, int dtype,
-                        std::vector<TShape> *in_shape,
-                        std::vector<TShape> *out_shape,
-                        Context ctx) {
-  Operator *op = NULL;
-#if MXNET_USE_NNPACK == 1
-  const size_t batch_size = (*in_shape)[0][0];
-  // nnp_fully_connected_inference will do optimization for batch-size = 1
-  // nnp_fully_connected_output will do optimization for batch-size > 1
-  switch (dtype) {
-  case mshadow::kFloat32:
-    return new NNPACKFullyConnectedOp<cpu, float>(param);
-  default:
-    break;
+
+static bool FullyConnectedShape(const nnvm::NodeAttrs& attrs,
+                                std::vector<TShape> *in_shape,
+                                std::vector<TShape> *out_shape) {
+  const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
+  using namespace mshadow;
+  if (!param.no_bias) {
+    CHECK_EQ(in_shape->size(), 3U) << "Input:[data, weight, bias]";
+  } else {
+    CHECK_EQ(in_shape->size(), 2U) << "Input:[data, weight]";
+  }
+  CHECK_EQ(out_shape->size(), 1U);
+  TShape dshape = (*in_shape)[fullc::kData];
+  TShape oshape = (*out_shape)[0];
+  // require data to be known
+  if (dshape.ndim() == 0) return false;
+
+  index_t num_input;
+  if (!param.flatten) {
+    num_input = dshape[dshape.ndim()-1];
+  } else {
+    num_input = dshape.ProdShape(1, dshape.ndim());
+  }
+  SHAPE_ASSIGN_CHECK(*in_shape, fullc::kWeight, Shape2(param.num_hidden, num_input));
+  if (!param.no_bias) {
+    SHAPE_ASSIGN_CHECK(*in_shape, fullc::kBias, Shape1(param.num_hidden));
+  }
+
+  if (!param.flatten) {
+    TShape result_shape(dshape);
+    result_shape[dshape.ndim()-1] = param.num_hidden;
+    SHAPE_ASSIGN_CHECK(*out_shape, 0, result_shape);
+  } else {
+    SHAPE_ASSIGN_CHECK(*out_shape, 0, Shape2(dshape[0], param.num_hidden));
+  }
+  if (oshape.ndim() != 0) {
+    dshape[0] = oshape[0];
+    SHAPE_ASSIGN_CHECK(*in_shape, fullc::kData, dshape);
+  }
+  return true;
+}
+
+void FullyConnectedCompute_CPU(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
+    const std::vector<NDArray> &inputs, const std::vector<OpReqType> &req,
+    const std::vector<NDArray> &outputs) {
+#if MXNET_USE_MKLDNN == 1
+  if (SupportMKLDNN(inputs[0])) {
+    MKLDNNFCForward(attrs, ctx, inputs, req, outputs);
+    return;
+  }
+#endif
+  std::vector<TBlob> in_blobs(inputs.size());
+  for (size_t i = 0; i < in_blobs.size(); i++)
+    in_blobs[i] = inputs[i].data();
+  std::vector<TBlob> out_blobs(outputs.size());
+  for (size_t i = 0; i < out_blobs.size(); i++)
+    out_blobs[i] = outputs[i].data();
+  FullyConnectedCompute<cpu>(attrs, ctx, in_blobs, req, out_blobs);
+}
+
+void FullyConnectedGradCompute_CPU(const nnvm::NodeAttrs& attrs,
+    const OpContext &ctx, const std::vector<NDArray> &inputs,
+    const std::vector<OpReqType> &req, const std::vector<NDArray> &outputs) {
+#if MXNET_USE_MKLDNN == 1
+  if (SupportMKLDNN(inputs[0])) {
+    MKLDNNFCBackward(attrs, ctx, inputs, req, outputs);
+    return;
   }
 #endif
-  switch (dtype) {
-  case mshadow::kFloat32:
-    op = new FullyConnectedOp<cpu, float>(param);
-    break;
-  case mshadow::kFloat64:
-    op = new FullyConnectedOp<cpu, double>(param);
-    break;
-  case mshadow::kFloat16:
-    LOG(FATAL) << "float16 fully connected layer is currently"
-                  "only supported by CuDNN version.";
-    break;
-  default:
-    LOG(FATAL) << "Unsupported type " << dtype;
+  std::vector<TBlob> in_blobs(inputs.size());
+  for (size_t i = 0; i < in_blobs.size(); i++)
+    in_blobs[i] = inputs[i].data();
+  std::vector<TBlob> out_blobs(outputs.size());
+  for (size_t i = 0; i < out_blobs.size(); i++)
+    out_blobs[i] = outputs[i].data();
+  FullyConnectedGradCompute<cpu>(attrs, ctx, in_blobs, req, out_blobs);
+}
+
+static bool FullyConnectedType(const nnvm::NodeAttrs& attrs,
+                               std::vector<int> *in_type, std::vector<int> *out_type) {
+  CHECK_GE(in_type->size(), 1U);
+  return ElemwiseAttr<int, type_is_none, type_assign, true, type_string>(
+      attrs, in_type, out_type, -1);
+}
+
+struct FullyConnectedGrad {
+  const char *op_name;
+  std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n,
+                                          const std::vector<nnvm::NodeEntry>& ograds) const {
+    std::vector<nnvm::NodeEntry> heads(ograds.begin(), ograds.end());
+    heads.push_back(n->inputs[fullc::kData]);
+    heads.push_back(n->inputs[fullc::kWeight]);
+    return MakeGradNode(op_name, n, heads, n->attrs.dict);
   }
+};
 
-  return op;
+inline static bool FCStorageType(const nnvm::NodeAttrs& attrs,
+                                 const int dev_mask,
+                                 DispatchMode* dispatch_mode,
+                                 std::vector<int> *in_attrs,
+                                 std::vector<int> *out_attrs) {
+  const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
+  uint32_t in_expected = param.no_bias ? 2 : 3;
+  CHECK_EQ(in_attrs->size(), in_expected);
+  CHECK_EQ(out_attrs->size(), 1);
+
+#if MXNET_USE_MKLDNN == 1
+  if (dev_mask == mshadow::cpu::kDevMask) {
+    *dispatch_mode = DispatchMode::kFComputeEx;
+    (*out_attrs)[0] = kMKLDNNStorage;
+    return true;
+  }
+#endif
+  *dispatch_mode = DispatchMode::kFCompute;
+  (*out_attrs)[0] = kDefaultStorage;
 
 Review comment:
   Does FullyConnected support sparse?
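
For context: the question points at FCStorageType above, which on CPU unconditionally selects DispatchMode::kFComputeEx and kMKLDNNStorage without inspecting the input storage types. Below is a minimal standalone sketch of a storage-type function that detects sparse inputs and falls back to the dense kernel; the enum stand-ins, the FCStorageTypeSketch name, and the fallback policy are assumptions for illustration, not code from this PR.

#include <cassert>
#include <vector>

// Hypothetical stand-ins for MXNet's storage-type and dispatch-mode enums,
// reduced to the values used below (the real definitions live in the MXNet
// headers, not here).
enum StorageType { kDefaultStorage, kRowSparseStorage, kCSRStorage };
enum class DispatchMode { kFCompute, kFComputeEx, kFComputeFallback };

// Sketch of a storage-type inference function that, unlike FCStorageType in
// the diff, inspects the inputs: all-dense inputs dispatch to the optimized
// path on CPU, while any sparse input falls back to the dense kernel.
bool FCStorageTypeSketch(bool cpu_dev,
                         DispatchMode* dispatch_mode,
                         std::vector<int>* in_attrs,
                         std::vector<int>* out_attrs) {
  bool all_dense = true;
  for (int st : *in_attrs)
    all_dense = all_dense && (st == kDefaultStorage);
  if (!all_dense) {
    // A sparse input: densify and run the plain FCompute kernel.
    *dispatch_mode = DispatchMode::kFComputeFallback;
  } else {
    *dispatch_mode = cpu_dev ? DispatchMode::kFComputeEx
                             : DispatchMode::kFCompute;
  }
  (*out_attrs)[0] = kDefaultStorage;
  return true;
}

int main() {
  // A row-sparse weight forces the fallback path in this sketch.
  std::vector<int> in = {kRowSparseStorage, kDefaultStorage};
  std::vector<int> out = {kDefaultStorage};
  DispatchMode mode;
  FCStorageTypeSketch(/*cpu_dev=*/true, &mode, &in, &out);
  assert(mode == DispatchMode::kFComputeFallback);
  return 0;
}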
