Kh4L commented on a change in pull request #18350:
URL: https://github.com/apache/incubator-mxnet/pull/18350#discussion_r427159560



##########
File path: src/c_api/c_api_symbolic.cc
##########
@@ -1383,47 +1394,78 @@ int MXOptimizeForBackend(SymbolHandle sym_handle,
   if (args_len || aux_len) {
     NDArray **in_args_ptr = reinterpret_cast<NDArray**>(in_args_handle);
     NDArray **in_aux_ptr = reinterpret_cast<NDArray**>(in_aux_handle);
-    Context default_ctx = 
Context::Create(static_cast<Context::DeviceType>(dev_type), 0);
-    mxnet::ShapeVector arg_shapes(args_len + aux_len);
-    nnvm::DTypeVector arg_dtypes(args_len + aux_len);
-    StorageTypeVector arg_stypes(args_len + aux_len);
-    size_t args_top = 0, aux_top = 0;
-    // loop over inputs to symbol in order and add to args/aux if mutable
-    for (size_t i = 0; i < num_forward_inputs; ++i) {
-      const uint32_t nid = indexed_graph.input_nodes().at(i);
-      if (mutable_nodes.count(nid)) {
-        CHECK_LT(aux_top, aux_len)
-          << "Cannot find aux '" << input_names[i] << "' in provided aux to 
optimize_for";
-        const auto &in_arg = *(in_aux_ptr[aux_top++]);
-        arg_shapes[i] = in_arg.shape();
-        arg_dtypes[i] = in_arg.dtype();
-        arg_stypes[i] = in_arg.storage_type();
-      } else {
-        CHECK_LT(args_top, args_len)
-          << "Cannot find arg '" << input_names[i] << "' in provided args to 
optimize_for";
-        const auto &in_arg = *(in_args_ptr[args_top++]);
-        arg_shapes[i] = in_arg.shape();
-        arg_dtypes[i] = in_arg.dtype();
-        arg_stypes[i] = in_arg.storage_type();
+    if (!skip_infer) {
+      Context default_ctx = 
Context::Create(static_cast<Context::DeviceType>(dev_type), 0);
+      mxnet::ShapeVector arg_shapes(args_len + aux_len);
+      nnvm::DTypeVector arg_dtypes(args_len + aux_len);
+      StorageTypeVector arg_stypes(args_len + aux_len);
+
+      // create the input shape, dtype and stype maps
+      std::unordered_map<std::string, mxnet::TShape> 
input_shape_map(num_input_shapes);
+      for (uint32_t i = 0; i < num_input_shapes; ++i) {
+        input_shape_map.emplace(input_shape_names[i],
+                    mxnet::TShape(input_shape_data + input_shape_idx[i],
+                    input_shape_data + input_shape_idx[i+1]));
+      }
+      std::unordered_map<std::string, int> input_dtype_map(num_input_dtypes);
+      for (uint32_t i = 0; i < num_input_dtypes; ++i) {
+        input_dtype_map.emplace(input_dtype_names[i], input_dtypes[i]);
+      }
+      std::unordered_map<std::string, int> input_stype_map(num_input_stypes);
+      for (uint32_t i = 0; i < num_input_stypes; ++i) {
+        input_stype_map.emplace(input_stype_names[i], input_stypes[i]);
       }
-    }
 
-    g.attrs["context"] = std::make_shared<nnvm::any>(
-        exec::ContextVector(indexed_graph.num_nodes(), default_ctx));
+      size_t args_top = 0, aux_top = 0;
+      // loop over inputs to symbol in order and add to args/aux if mutable
+      for (size_t i = 0; i < num_forward_inputs; ++i) {
+        const uint32_t nid = indexed_graph.input_nodes().at(i);
+        if (mutable_nodes.count(nid)) {
+          auto name = input_names[i];
+          CHECK_LT(aux_top, aux_len)
+            << "Cannot find aux '" << name << "' in provided aux to 
optimize_for";
+          if (in_aux_ptr[aux_top] != nullptr) {
+            const auto &in_arg = *(in_aux_ptr[aux_top]);
+            arg_shapes[i] = in_arg.shape();
+            arg_dtypes[i] = in_arg.dtype();
+            arg_stypes[i] = in_arg.storage_type();
+          } else {
+            auto it_shape = input_shape_map.find(name);
+            if (it_shape != input_shape_map.end()) {
+              arg_shapes[i] = it_shape->second;
+            }

Review comment:
       This logic should be in the other clause (for `args`, not `aux`). I just 
updated it.

##########
File path: src/c_api/c_api_symbolic.cc
##########
@@ -1383,47 +1394,78 @@ int MXOptimizeForBackend(SymbolHandle sym_handle,
   if (args_len || aux_len) {
     NDArray **in_args_ptr = reinterpret_cast<NDArray**>(in_args_handle);
     NDArray **in_aux_ptr = reinterpret_cast<NDArray**>(in_aux_handle);
-    Context default_ctx = 
Context::Create(static_cast<Context::DeviceType>(dev_type), 0);
-    mxnet::ShapeVector arg_shapes(args_len + aux_len);
-    nnvm::DTypeVector arg_dtypes(args_len + aux_len);
-    StorageTypeVector arg_stypes(args_len + aux_len);
-    size_t args_top = 0, aux_top = 0;
-    // loop over inputs to symbol in order and add to args/aux if mutable
-    for (size_t i = 0; i < num_forward_inputs; ++i) {
-      const uint32_t nid = indexed_graph.input_nodes().at(i);
-      if (mutable_nodes.count(nid)) {
-        CHECK_LT(aux_top, aux_len)
-          << "Cannot find aux '" << input_names[i] << "' in provided aux to 
optimize_for";
-        const auto &in_arg = *(in_aux_ptr[aux_top++]);
-        arg_shapes[i] = in_arg.shape();
-        arg_dtypes[i] = in_arg.dtype();
-        arg_stypes[i] = in_arg.storage_type();
-      } else {
-        CHECK_LT(args_top, args_len)
-          << "Cannot find arg '" << input_names[i] << "' in provided args to 
optimize_for";
-        const auto &in_arg = *(in_args_ptr[args_top++]);
-        arg_shapes[i] = in_arg.shape();
-        arg_dtypes[i] = in_arg.dtype();
-        arg_stypes[i] = in_arg.storage_type();
+    if (!skip_infer) {
+      Context default_ctx = 
Context::Create(static_cast<Context::DeviceType>(dev_type), 0);
+      mxnet::ShapeVector arg_shapes(args_len + aux_len);
+      nnvm::DTypeVector arg_dtypes(args_len + aux_len);
+      StorageTypeVector arg_stypes(args_len + aux_len);
+
+      // create the input shape, dtype and stype maps
+      std::unordered_map<std::string, mxnet::TShape> 
input_shape_map(num_input_shapes);
+      for (uint32_t i = 0; i < num_input_shapes; ++i) {
+        input_shape_map.emplace(input_shape_names[i],
+                    mxnet::TShape(input_shape_data + input_shape_idx[i],
+                    input_shape_data + input_shape_idx[i+1]));
+      }
+      std::unordered_map<std::string, int> input_dtype_map(num_input_dtypes);
+      for (uint32_t i = 0; i < num_input_dtypes; ++i) {
+        input_dtype_map.emplace(input_dtype_names[i], input_dtypes[i]);
+      }
+      std::unordered_map<std::string, int> input_stype_map(num_input_stypes);
+      for (uint32_t i = 0; i < num_input_stypes; ++i) {
+        input_stype_map.emplace(input_stype_names[i], input_stypes[i]);
       }
-    }
 
-    g.attrs["context"] = std::make_shared<nnvm::any>(
-        exec::ContextVector(indexed_graph.num_nodes(), default_ctx));
+      size_t args_top = 0, aux_top = 0;
+      // loop over inputs to symbol in order and add to args/aux if mutable
+      for (size_t i = 0; i < num_forward_inputs; ++i) {
+        const uint32_t nid = indexed_graph.input_nodes().at(i);
+        if (mutable_nodes.count(nid)) {
+          auto name = input_names[i];
+          CHECK_LT(aux_top, aux_len)
+            << "Cannot find aux '" << name << "' in provided aux to 
optimize_for";
+          if (in_aux_ptr[aux_top] != nullptr) {
+            const auto &in_arg = *(in_aux_ptr[aux_top]);
+            arg_shapes[i] = in_arg.shape();
+            arg_dtypes[i] = in_arg.dtype();
+            arg_stypes[i] = in_arg.storage_type();
+          } else {
+            auto it_shape = input_shape_map.find(name);
+            if (it_shape != input_shape_map.end()) {
+              arg_shapes[i] = it_shape->second;
+            }
+            auto it_type = input_dtype_map.find(name);
+            if (it_type != input_dtype_map.end()) {
+              arg_dtypes[i] = it_type->second;
+            }
+            it_type = input_stype_map.find(name);
+            if (it_type != input_stype_map.end()) {
+              arg_stypes[i] = it_type->second;
+            }
+          }
+          aux_top++;
+        } else {
+          CHECK_LT(args_top, args_len)
+            << "Cannot find arg '" << input_names[i] << "' in provided args to 
optimize_for";
+          if (in_args_ptr[args_top] != nullptr) {
+            const auto &in_arg = *(in_args_ptr[args_top]);
+            arg_shapes[i] = in_arg.shape();
+            arg_dtypes[i] = in_arg.dtype();
+            arg_stypes[i] = in_arg.storage_type();
+          }

Review comment:
       AFAIK, we don't need it here, as an input can only be in `args`, not in 
`aux`.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to