apeforest commented on a change in pull request #15593: Large Index Support for Slice
URL: https://github.com/apache/incubator-mxnet/pull/15593#discussion_r310782505
 
 

 ##########
 File path: src/c_api/c_api_symbolic.cc
 ##########
 @@ -585,6 +586,96 @@ int MXSymbolInferShape(SymbolHandle sym,
   API_END();
 }
 
+template<typename dtype, typename stype, typename itype>
+inline void SymbolInferShape(const char** keys,
+                      mx_uint num_args,
+                      const dtype* arg_shape_data,
+                      const itype* arg_ind_ptr,
+                      const int** in_shape_ndim,
+                      const dtype*** in_shape_data,
+                      const int** out_shape_ndim,
+                      const dtype*** out_shape_data,
+                      const int** aux_shape_ndim,
+                      const dtype*** aux_shape_data,
+                      nnvm::Symbol* s,
+                      MXAPIThreadLocalEntry<dtype>* ret,
+                      stype* in_shape_size,
+                      stype* out_shape_size,
+                      stype* aux_shape_size,
+                      int* complete) {
+nnvm::Graph g = Symbol2Graph(*s);
+mxnet::ShapeVector arg_shapes(g.indexed_graph().input_nodes().size(), 
mxnet::TShape());
+if (keys == nullptr && num_args != 0) {
+  std::vector < uint32_t > read_only_args = 
mxnet::ReadOnlyArgIndices(g.indexed_graph());
+  CHECK_LE(num_args, read_only_args.size());
+  for (mx_uint i = 0; i < num_args; ++i) {
+    arg_shapes[read_only_args[i]] = mxnet::ShapeTypeCast(arg_shape_data + 
arg_ind_ptr[i],
+                                                         arg_shape_data + 
arg_ind_ptr[i + 1]);
+  }
+} else {
+  std::unordered_map<std::string, mxnet::TShape> kwargs;
+  for (mx_uint i = 0; i < num_args; ++i) {
+    kwargs[keys[i]] = mxnet::ShapeTypeCast(arg_shape_data + arg_ind_ptr[i],
+                                           arg_shape_data + arg_ind_ptr[i + 
1]);
+  }
+  mxnet::MatchArguments(g.indexed_graph(), kwargs, &arg_shapes, "InferShape");
+}
+try {
+  g = mxnet::exec::InferShape(std::move(g), std::move(arg_shapes), 
"__shape__");
+} catch (const mxnet::op::InferShapeError& err) {
+  throw dmlc::Error(err.msg);
+}
+// if use legacy shape definition, need to convert numpy shape to legacy shape
+mxnet::ShapeVector shapes = g.GetAttr<mxnet::ShapeVector>("shape");
+if (!Imperative::Get()->is_np_shape()) {
+  common::ConvertToLegacyShape(&shapes);
+}
+// copy back
+CopyAttr(g.indexed_graph(), shapes, &(ret->arg_shapes), &(ret->out_shapes), 
&(ret->aux_shapes));
+// copy data back
+//if mxnet::features.is_enabled(INT64_TENSOR_SIZE){
+MXAPIThreadLocalEntry<dtype>::SetupShapeArrayReturnWithBufferEx(ret->arg_shapes,
+                                                                
&(ret->arg_shape_ndim_ex),
+                                                                
&(ret->arg_shape_data_ex),
+                                                                
&(ret->arg_shape_buffer_ex));
+MXAPIThreadLocalEntry<dtype>::SetupShapeArrayReturnWithBufferEx(ret->out_shapes,
+                                                                
&(ret->out_shape_ndim_ex),
+                                                                
&(ret->out_shape_data_ex),
+                                                                
&(ret->out_shape_buffer_ex));
+MXAPIThreadLocalEntry<dtype>::SetupShapeArrayReturnWithBufferEx(ret->aux_shapes,
+                                                                
&(ret->aux_shape_ndim_ex),
+                                                                
&(ret->aux_shape_data_ex),
+                                                                
&(ret->aux_shape_buffer_ex));
+ *in_shape_data = dmlc::BeginPtr(ret->arg_shape_data_ex);
+ *out_shape_data = dmlc::BeginPtr(ret->out_shape_data_ex);
+ *aux_shape_data = dmlc::BeginPtr(ret->aux_shape_data_ex);
+//} else {
 
 Review comment:
   Please remove the commented-out code (the `//if mxnet::features.is_enabled(INT64_TENSOR_SIZE){` and `//} else {` lines) before merging.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to respond to this review thread.