mseth10 commented on a change in pull request #18350:
URL: https://github.com/apache/incubator-mxnet/pull/18350#discussion_r427101732



##########
File path: src/c_api/c_api_symbolic.cc
##########
@@ -1383,47 +1394,78 @@ int MXOptimizeForBackend(SymbolHandle sym_handle,
   if (args_len || aux_len) {
     NDArray **in_args_ptr = reinterpret_cast<NDArray**>(in_args_handle);
     NDArray **in_aux_ptr = reinterpret_cast<NDArray**>(in_aux_handle);
-    Context default_ctx = 
Context::Create(static_cast<Context::DeviceType>(dev_type), 0);
-    mxnet::ShapeVector arg_shapes(args_len + aux_len);
-    nnvm::DTypeVector arg_dtypes(args_len + aux_len);
-    StorageTypeVector arg_stypes(args_len + aux_len);
-    size_t args_top = 0, aux_top = 0;
-    // loop over inputs to symbol in order and add to args/aux if mutable
-    for (size_t i = 0; i < num_forward_inputs; ++i) {
-      const uint32_t nid = indexed_graph.input_nodes().at(i);
-      if (mutable_nodes.count(nid)) {
-        CHECK_LT(aux_top, aux_len)
-          << "Cannot find aux '" << input_names[i] << "' in provided aux to 
optimize_for";
-        const auto &in_arg = *(in_aux_ptr[aux_top++]);
-        arg_shapes[i] = in_arg.shape();
-        arg_dtypes[i] = in_arg.dtype();
-        arg_stypes[i] = in_arg.storage_type();
-      } else {
-        CHECK_LT(args_top, args_len)
-          << "Cannot find arg '" << input_names[i] << "' in provided args to 
optimize_for";
-        const auto &in_arg = *(in_args_ptr[args_top++]);
-        arg_shapes[i] = in_arg.shape();
-        arg_dtypes[i] = in_arg.dtype();
-        arg_stypes[i] = in_arg.storage_type();
+    if (!skip_infer) {
+      Context default_ctx = 
Context::Create(static_cast<Context::DeviceType>(dev_type), 0);
+      mxnet::ShapeVector arg_shapes(args_len + aux_len);
+      nnvm::DTypeVector arg_dtypes(args_len + aux_len);
+      StorageTypeVector arg_stypes(args_len + aux_len);
+
+      // create the input shape, dtype and stype maps
+      std::unordered_map<std::string, mxnet::TShape> 
input_shape_map(num_input_shapes);
+      for (uint32_t i = 0; i < num_input_shapes; ++i) {
+        input_shape_map.emplace(input_shape_names[i],
+                    mxnet::TShape(input_shape_data + input_shape_idx[i],
+                    input_shape_data + input_shape_idx[i+1]));
+      }
+      std::unordered_map<std::string, int> input_dtype_map(num_input_dtypes);
+      for (uint32_t i = 0; i < num_input_dtypes; ++i) {
+        input_dtype_map.emplace(input_dtype_names[i], input_dtypes[i]);
+      }
+      std::unordered_map<std::string, int> input_stype_map(num_input_stypes);
+      for (uint32_t i = 0; i < num_input_stypes; ++i) {
+        input_stype_map.emplace(input_stype_names[i], input_stypes[i]);
       }
-    }
 
-    g.attrs["context"] = std::make_shared<nnvm::any>(
-        exec::ContextVector(indexed_graph.num_nodes(), default_ctx));
+      size_t args_top = 0, aux_top = 0;
+      // loop over inputs to symbol in order and add to args/aux if mutable
+      for (size_t i = 0; i < num_forward_inputs; ++i) {
+        const uint32_t nid = indexed_graph.input_nodes().at(i);
+        if (mutable_nodes.count(nid)) {
+          auto name = input_names[i];
+          CHECK_LT(aux_top, aux_len)
+            << "Cannot find aux '" << name << "' in provided aux to 
optimize_for";
+          if (in_aux_ptr[aux_top] != nullptr) {
+            const auto &in_arg = *(in_aux_ptr[aux_top]);
+            arg_shapes[i] = in_arg.shape();
+            arg_dtypes[i] = in_arg.dtype();
+            arg_stypes[i] = in_arg.storage_type();
+          } else {
+            auto it_shape = input_shape_map.find(name);
+            if (it_shape != input_shape_map.end()) {
+              arg_shapes[i] = it_shape->second;
+            }

Review comment:
       Should there be an else clause here? What happens if `name` is not found in
`input_shape_map` — do we error out, or silently leave the shape unset?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Reply via email to