Kh4L commented on a change in pull request #18350:
URL: https://github.com/apache/incubator-mxnet/pull/18350#discussion_r427153786
##########
File path: src/c_api/c_api_symbolic.cc
##########
@@ -1383,47 +1394,78 @@ int MXOptimizeForBackend(SymbolHandle sym_handle,
if (args_len || aux_len) {
NDArray **in_args_ptr = reinterpret_cast<NDArray**>(in_args_handle);
NDArray **in_aux_ptr = reinterpret_cast<NDArray**>(in_aux_handle);
- Context default_ctx =
Context::Create(static_cast<Context::DeviceType>(dev_type), 0);
- mxnet::ShapeVector arg_shapes(args_len + aux_len);
- nnvm::DTypeVector arg_dtypes(args_len + aux_len);
- StorageTypeVector arg_stypes(args_len + aux_len);
- size_t args_top = 0, aux_top = 0;
- // loop over inputs to symbol in order and add to args/aux if mutable
- for (size_t i = 0; i < num_forward_inputs; ++i) {
- const uint32_t nid = indexed_graph.input_nodes().at(i);
- if (mutable_nodes.count(nid)) {
- CHECK_LT(aux_top, aux_len)
- << "Cannot find aux '" << input_names[i] << "' in provided aux to
optimize_for";
- const auto &in_arg = *(in_aux_ptr[aux_top++]);
- arg_shapes[i] = in_arg.shape();
- arg_dtypes[i] = in_arg.dtype();
- arg_stypes[i] = in_arg.storage_type();
- } else {
- CHECK_LT(args_top, args_len)
- << "Cannot find arg '" << input_names[i] << "' in provided args to
optimize_for";
- const auto &in_arg = *(in_args_ptr[args_top++]);
- arg_shapes[i] = in_arg.shape();
- arg_dtypes[i] = in_arg.dtype();
- arg_stypes[i] = in_arg.storage_type();
+ if (!skip_infer) {
+ Context default_ctx =
Context::Create(static_cast<Context::DeviceType>(dev_type), 0);
+ mxnet::ShapeVector arg_shapes(args_len + aux_len);
+ nnvm::DTypeVector arg_dtypes(args_len + aux_len);
+ StorageTypeVector arg_stypes(args_len + aux_len);
+
+ // create the input shape, dtype and stype maps
+ std::unordered_map<std::string, mxnet::TShape>
input_shape_map(num_input_shapes);
+ for (uint32_t i = 0; i < num_input_shapes; ++i) {
+ input_shape_map.emplace(input_shape_names[i],
+ mxnet::TShape(input_shape_data + input_shape_idx[i],
+ input_shape_data + input_shape_idx[i+1]));
+ }
+ std::unordered_map<std::string, int> input_dtype_map(num_input_dtypes);
+ for (uint32_t i = 0; i < num_input_dtypes; ++i) {
+ input_dtype_map.emplace(input_dtype_names[i], input_dtypes[i]);
+ }
+ std::unordered_map<std::string, int> input_stype_map(num_input_stypes);
+ for (uint32_t i = 0; i < num_input_stypes; ++i) {
+ input_stype_map.emplace(input_stype_names[i], input_stypes[i]);
}
- }
- g.attrs["context"] = std::make_shared<nnvm::any>(
- exec::ContextVector(indexed_graph.num_nodes(), default_ctx));
+ size_t args_top = 0, aux_top = 0;
+ // loop over inputs to symbol in order and add to args/aux if mutable
+ for (size_t i = 0; i < num_forward_inputs; ++i) {
+ const uint32_t nid = indexed_graph.input_nodes().at(i);
+ if (mutable_nodes.count(nid)) {
+ auto name = input_names[i];
+ CHECK_LT(aux_top, aux_len)
+ << "Cannot find aux '" << name << "' in provided aux to
optimize_for";
+ if (in_aux_ptr[aux_top] != nullptr) {
+ const auto &in_arg = *(in_aux_ptr[aux_top]);
+ arg_shapes[i] = in_arg.shape();
+ arg_dtypes[i] = in_arg.dtype();
+ arg_stypes[i] = in_arg.storage_type();
+ } else {
+ auto it_shape = input_shape_map.find(name);
+ if (it_shape != input_shape_map.end()) {
+ arg_shapes[i] = it_shape->second;
+ }
+ auto it_type = input_dtype_map.find(name);
+ if (it_type != input_dtype_map.end()) {
+ arg_dtypes[i] = it_type->second;
+ }
+ it_type = input_stype_map.find(name);
+ if (it_type != input_stype_map.end()) {
+ arg_stypes[i] = it_type->second;
+ }
+ }
+ aux_top++;
+ } else {
+ CHECK_LT(args_top, args_len)
+ << "Cannot find arg '" << input_names[i] << "' in provided args to
optimize_for";
+ if (in_args_ptr[args_top] != nullptr) {
+ const auto &in_arg = *(in_args_ptr[args_top]);
+ arg_shapes[i] = in_arg.shape();
+ arg_dtypes[i] = in_arg.dtype();
+ arg_stypes[i] = in_arg.storage_type();
+ }
+ args_top++;
+ }
+ }
- // infer shapes
- g = exec::InferShape(std::move(g), std::move(arg_shapes), "__shape__");
- // infer dtypes
- g = exec::InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
- if (g.GetAttr<size_t>("dtype_num_unknown_nodes") != 0U) {
- common::HandleInferTypeError(num_forward_inputs, indexed_graph,
Review comment:
The goal of this PR is to let the user skip the
inference pass, either explicitly with `skip_infer` or implicitly by relying on
partial attribute inference.
Some backends may require attribute inference while others do not, so the error
should be handled by the backend.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]