Kh4L commented on a change in pull request #18350:
URL: https://github.com/apache/incubator-mxnet/pull/18350#discussion_r427153786



##########
File path: src/c_api/c_api_symbolic.cc
##########
@@ -1383,47 +1394,78 @@ int MXOptimizeForBackend(SymbolHandle sym_handle,
   if (args_len || aux_len) {
     NDArray **in_args_ptr = reinterpret_cast<NDArray**>(in_args_handle);
     NDArray **in_aux_ptr = reinterpret_cast<NDArray**>(in_aux_handle);
-    Context default_ctx = 
Context::Create(static_cast<Context::DeviceType>(dev_type), 0);
-    mxnet::ShapeVector arg_shapes(args_len + aux_len);
-    nnvm::DTypeVector arg_dtypes(args_len + aux_len);
-    StorageTypeVector arg_stypes(args_len + aux_len);
-    size_t args_top = 0, aux_top = 0;
-    // loop over inputs to symbol in order and add to args/aux if mutable
-    for (size_t i = 0; i < num_forward_inputs; ++i) {
-      const uint32_t nid = indexed_graph.input_nodes().at(i);
-      if (mutable_nodes.count(nid)) {
-        CHECK_LT(aux_top, aux_len)
-          << "Cannot find aux '" << input_names[i] << "' in provided aux to 
optimize_for";
-        const auto &in_arg = *(in_aux_ptr[aux_top++]);
-        arg_shapes[i] = in_arg.shape();
-        arg_dtypes[i] = in_arg.dtype();
-        arg_stypes[i] = in_arg.storage_type();
-      } else {
-        CHECK_LT(args_top, args_len)
-          << "Cannot find arg '" << input_names[i] << "' in provided args to 
optimize_for";
-        const auto &in_arg = *(in_args_ptr[args_top++]);
-        arg_shapes[i] = in_arg.shape();
-        arg_dtypes[i] = in_arg.dtype();
-        arg_stypes[i] = in_arg.storage_type();
+    if (!skip_infer) {
+      Context default_ctx = 
Context::Create(static_cast<Context::DeviceType>(dev_type), 0);
+      mxnet::ShapeVector arg_shapes(args_len + aux_len);
+      nnvm::DTypeVector arg_dtypes(args_len + aux_len);
+      StorageTypeVector arg_stypes(args_len + aux_len);
+
+      // create the input shape, dtype and stype maps
+      std::unordered_map<std::string, mxnet::TShape> 
input_shape_map(num_input_shapes);
+      for (uint32_t i = 0; i < num_input_shapes; ++i) {
+        input_shape_map.emplace(input_shape_names[i],
+                    mxnet::TShape(input_shape_data + input_shape_idx[i],
+                    input_shape_data + input_shape_idx[i+1]));
+      }
+      std::unordered_map<std::string, int> input_dtype_map(num_input_dtypes);
+      for (uint32_t i = 0; i < num_input_dtypes; ++i) {
+        input_dtype_map.emplace(input_dtype_names[i], input_dtypes[i]);
+      }
+      std::unordered_map<std::string, int> input_stype_map(num_input_stypes);
+      for (uint32_t i = 0; i < num_input_stypes; ++i) {
+        input_stype_map.emplace(input_stype_names[i], input_stypes[i]);
       }
-    }
 
-    g.attrs["context"] = std::make_shared<nnvm::any>(
-        exec::ContextVector(indexed_graph.num_nodes(), default_ctx));
+      size_t args_top = 0, aux_top = 0;
+      // loop over inputs to symbol in order and add to args/aux if mutable
+      for (size_t i = 0; i < num_forward_inputs; ++i) {
+        const uint32_t nid = indexed_graph.input_nodes().at(i);
+        if (mutable_nodes.count(nid)) {
+          auto name = input_names[i];
+          CHECK_LT(aux_top, aux_len)
+            << "Cannot find aux '" << name << "' in provided aux to 
optimize_for";
+          if (in_aux_ptr[aux_top] != nullptr) {
+            const auto &in_arg = *(in_aux_ptr[aux_top]);
+            arg_shapes[i] = in_arg.shape();
+            arg_dtypes[i] = in_arg.dtype();
+            arg_stypes[i] = in_arg.storage_type();
+          } else {
+            auto it_shape = input_shape_map.find(name);
+            if (it_shape != input_shape_map.end()) {
+              arg_shapes[i] = it_shape->second;
+            }
+            auto it_type = input_dtype_map.find(name);
+            if (it_type != input_dtype_map.end()) {
+              arg_dtypes[i] = it_type->second;
+            }
+            it_type = input_stype_map.find(name);
+            if (it_type != input_stype_map.end()) {
+              arg_stypes[i] = it_type->second;
+            }
+          }
+          aux_top++;
+        } else {
+          CHECK_LT(args_top, args_len)
+            << "Cannot find arg '" << input_names[i] << "' in provided args to 
optimize_for";
+          if (in_args_ptr[args_top] != nullptr) {
+            const auto &in_arg = *(in_args_ptr[args_top]);
+            arg_shapes[i] = in_arg.shape();
+            arg_dtypes[i] = in_arg.dtype();
+            arg_stypes[i] = in_arg.storage_type();
+          }
+          args_top++;
+        }
+      }
 
-    // infer shapes
-    g = exec::InferShape(std::move(g), std::move(arg_shapes), "__shape__");
-    // infer dtypes
-    g = exec::InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
-    if (g.GetAttr<size_t>("dtype_num_unknown_nodes") != 0U) {
-      common::HandleInferTypeError(num_forward_inputs, indexed_graph,

Review comment:
       The goal of this PR is to let the user skip the 
inference pass, either explicitly with `skip_infer` or implicitly, by relying on 
partial attribute inference.
   
   Some backends may require attribute inference while others do not, so the error 
should be handled by the backend.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to