cjolivier01 commented on a change in pull request #9677: Refactor operators and 
add MKLDNN
URL: https://github.com/apache/incubator-mxnet/pull/9677#discussion_r168576406
 
 

 ##########
 File path: tests/cpp/include/test_core_op.h
 ##########
 @@ -311,68 +339,166 @@ class CoreOpExecutor : public 
test::op::OperatorDataInitializer<DType>
       op_ = nnvm::Op::Get(op_name);
       CHECK_NOTNULL(op_);
 
+      std::map<int, const NDArray *> index2array;
+      nnvm::NodePtr bwd_node_ptr;
+      if (backward_for_op) {
+        bwd_node_ptr = backward_for_op->CalcBackwardPass(&index2array);
+      }
+
       // Set up forward
       attrs_ = ParseAttrs(op_, args);
 
-      const int num_inputs = op_->num_inputs;
+      int num_inputs = op_->num_inputs;
+      if (op_->get_num_inputs) {
+        num_inputs = op_->get_num_inputs(attrs_);
+      } else if (backward_for_op) {
+        if (bwd_node_ptr) {
+          num_inputs = static_cast<int>(bwd_node_ptr->inputs.size());
+        }
+      }
 
       if (!inputs.empty()) {
         CHECK_EQ(inputs.size(), static_cast<size_t>(num_inputs));
       }
 
-      int inferred_num_outputs, num_visible_outputs;
+      int inferred_num_outputs /*, num_visible_outputs*/;
 
-      imperative::SetNumOutputs(op_, attrs_, num_inputs, &inferred_num_outputs,
-                                &num_visible_outputs);
+      if (op_->get_num_outputs) {
+        inferred_num_outputs = op_->get_num_outputs(attrs_);
+      } else {
+        inferred_num_outputs = op_->num_outputs;
+      }
 
       // Generic, all shapes the same. Probably this will need to be adjusted 
for more complex
       // operators such as dot
-      std::vector<TShape> shapes;
-      for (size_t i = 0, n = std::max(num_visible_outputs, num_inputs); i < n; 
++i) {
-        shapes.emplace_back(i < input_shapes_.size() ? input_shapes_[i]
-                                                  : 
input_shapes_[input_shapes_.size() - 1]);
+      std::vector<nnvm::TShape> input_shapes;
+      if (!input_shapes_.empty()) {
+        for (size_t i = 0, n = num_inputs; i < n; ++i) {
+          input_shapes.emplace_back(i < input_shapes_.size() ? input_shapes_[i]
+                                                             : 
input_shapes_[input_shapes_.size()
+                                                                             - 
1]);
+        }
       }
       std::vector<NDArray *> inputs_p, outputs_p;
 
       if (!outputs.empty()) {
-        CHECK_EQ(outputs.size(), static_cast<size_t>(num_visible_outputs));
+        CHECK_EQ(outputs.size(), static_cast<size_t>(inferred_num_outputs));
       }
 
       inputs_.reserve(num_inputs);
       inputs_p.reserve(num_inputs);
-      outputs_.reserve(num_visible_outputs);
-      outputs_p.reserve(num_visible_outputs);
+      outputs_.reserve(inferred_num_outputs);
+      outputs_p.reserve(inferred_num_outputs);
+
+      std::vector<int> input_types;
+      input_types.reserve(num_inputs);
+      std::vector<int> output_types;
+      output_types.reserve(inferred_num_outputs);
+
+      static auto& finfer_type = Op::GetAttr<nnvm::FInferType>("FInferType");
+      if (finfer_type.count(op_)) {
+        input_types.resize(num_inputs, -1);
+        input_types[0] = default_dtype();  // Set first input to default type
+        output_types.resize(inferred_num_outputs, -1);
+        finfer_type[op_](attrs_, &input_types, &output_types);
+        CHECK_EQ(input_types.size(), num_inputs);
+        CHECK_EQ(output_types.size(), inferred_num_outputs);
+      } else {
+        if (backward_for_op) {
+          if (bwd_node_ptr) {
+            CHECK_EQ(bwd_node_ptr->inputs.size(), num_inputs);
+            input_types.resize(bwd_node_ptr->inputs.size(), -1);
+            for (size_t i = 0; i < num_inputs; ++i) {
+              const int map_key = bwd_node_ptr->inputs[i].index;
+              CHECK(index2array.find(map_key) != index2array.end());
+              const int dtype = index2array[map_key]->dtype();
+              input_types[i] = dtype;
+            }
+            for (const auto &fwd_inp : backward_for_op->inputs()) {
+              const int dtype = fwd_inp.data().type_flag_;
+              output_types.emplace_back(dtype);
+            }
+          } else {
+            for (size_t x = 0; x < num_inputs; ++x) {
+              input_types.emplace_back(default_dtype());
+            }
+            for (const auto &fwd_inp : backward_for_op->inputs()) {
+              const int dtype = fwd_inp.data().type_flag_;
+              output_types.emplace_back(dtype);
+            }
+          }
+        } else {
+          CHECK(false);  // above always true?
+          for (size_t x = 0; x < num_inputs; ++x) {
 
 Review comment:
   This is a notification to the developer that this case is possible. It's test
 code.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to