junrushao1994 commented on a change in pull request #11760: [MXNET-684] Add `ifelse` operator
URL: https://github.com/apache/incubator-mxnet/pull/11760#discussion_r204130114
 
 

 ##########
 File path: src/operator/control_flow.cc
 ##########
 @@ -977,6 +913,342 @@ WhileLoopGradient(const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& og
   return entries;
 }
 
+struct IfelseParam : public dmlc::Parameter<IfelseParam> {
+  int num_args;
+  int num_outputs;
+  nnvm::Tuple<dim_t> cond_input_locs;
+  nnvm::Tuple<dim_t> then_input_locs;
+  nnvm::Tuple<dim_t> else_input_locs;
+  DMLC_DECLARE_PARAMETER(IfelseParam) {
+    DMLC_DECLARE_FIELD(num_args).set_lower_bound(3)
+    .describe("Number of input arguments, including cond, then and else as 
three symbol inputs.");
+    DMLC_DECLARE_FIELD(num_outputs).set_lower_bound(1)
+    .describe("The number of outputs of the subgraph.");
+    DMLC_DECLARE_FIELD(cond_input_locs)
+    .describe("The locations of cond's inputs in the given inputs.");
+    DMLC_DECLARE_FIELD(then_input_locs)
+    .describe("The locations of then's inputs in the given inputs.");
+    DMLC_DECLARE_FIELD(else_input_locs)
+    .describe("The locations of else's inputs in the given inputs.");
+  }
+};  // struct IfelseParam
+
+DMLC_REGISTER_PARAMETER(IfelseParam);
+
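+// IfelseState caches the condition as a CachedOp and each branch as a
+// LoopState, and records which branch the forward pass took so that
+// the backward pass can replay the same branch.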
+class IfelseState {
+ public:
+  IfelseParam params;
+  CachedOpPtr cond_op;
+  LoopState then_branch;
+  LoopState else_branch;
+  int branch_selection;  // 1 if then branch; 0 if else branch; -1 if undefined
+
+  IfelseState(const IfelseParam &params,
+              const Symbol &cond,
+              const Symbol &then_sym,
+              const Symbol &else_sym):
+              params(params),
+              cond_op(LoopState::MakeSharedOp(cond)),
+              then_branch(then_sym),
+              else_branch(else_sym),
+              branch_selection(-1) {
+  }
+};
+
+static void IfelseComputeExCPU(const OpStatePtr& state_ptr,
+                               const OpContext& ctx,
+                               const std::vector<NDArray>& inputs,
+                               const std::vector<OpReqType>& req,
+                               const std::vector<NDArray>& outputs) {
+  // The argument `inputs' holds the inputs of cond, then and else;
+  // their locations are given by `cond_input_locs', `then_input_locs'
+  // and `else_input_locs' respectively.
+  // The argument `outputs' holds the outputs of the branch that is taken.
+  IfelseState &state = state_ptr.get_state<IfelseState>();
+  const IfelseParam& params = state.params;
+  // a helper function that converts std::vector<NDArray> to std::vector<NDArray*>
+  const auto to_ptr_vec = [](std::vector<NDArray> &in, std::vector<NDArray*> *out) {
+    out->clear();
+    out->reserve(in.size());
+    std::transform(std::begin(in),
+                   std::end(in),
+                   std::back_inserter(*out),
+                   [](NDArray &a) {return &a;});
+  };
+  // sanity checks
+  CHECK_EQ(inputs.size() + 3U, (size_t) params.num_args);
+  CHECK_EQ(outputs.size(), (size_t) params.num_outputs);
+  CHECK_EQ(outputs.size(), req.size());
+  // construct inputs and outputs for cond
+  std::vector<NDArray> cond_inputs;
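+  // cond produces exactly one boolean scalar, so a single placeholder output suffices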
+  std::vector<NDArray> cond_outputs = {NDArray()};
+  std::vector<NDArray*> cond_input_ptr;
+  std::vector<NDArray*> cond_output_ptr;
+  extract_by_loc(inputs, params.cond_input_locs, &cond_inputs);
+  to_ptr_vec(cond_inputs, &cond_input_ptr);
+  to_ptr_vec(cond_outputs, &cond_output_ptr);
+  int &branch_selection = state.branch_selection;
+  // run cond
+  state.cond_op->Forward(nullptr, cond_input_ptr, cond_output_ptr);
+  branch_selection = as_bool_scalar(*cond_output_ptr[0]);
+  // select the right branch
+  const nnvm::Tuple<dim_t> &func_input_locs = branch_selection
+                                            ? params.then_input_locs
+                                            : params.else_input_locs;
+  LoopState &loop_state = branch_selection
+                        ? state.then_branch
+                        : state.else_branch;
+  // extract inputs for the branch
+  std::vector<NDArray> func_inputs;
+  extract_by_loc(inputs, func_input_locs, &func_inputs);
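+  // run the selected branch exactly once; the step id is always 0
+  // because, unlike while_loop, there is no iteration here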
+  loop_state.Forward(0, func_inputs, req, outputs, ctx.need_grad);
+}
+
+static void IfelseGradComputeExCPU(const OpStatePtr& state_ptr,
+                                   const OpContext& ctx,
+                                   const std::vector<NDArray>& inputs,
+                                   const std::vector<OpReqType>& _req,
+                                   const std::vector<NDArray>& outputs) {
+  IfelseState &state = state_ptr.get_state<IfelseState>();
+  const IfelseParam& params = state.params;
+  // sanity checks
+  CHECK_EQ(outputs.size() + 3U, (size_t) params.num_args);
+  CHECK_EQ(outputs.size(), _req.size());
+  // select the right branch
+  int branch_selection = state.branch_selection;
+  CHECK_NE(branch_selection, -1);
+  const nnvm::Tuple<dim_t> &func_input_locs = branch_selection
+                                            ? params.then_input_locs
+                                            : params.else_input_locs;
+  LoopState &loop_state = branch_selection
+                        ? state.then_branch
+                        : state.else_branch;
+  // construct parameters
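+  // `ograds' are the head gradients of the operator's outputs;
+  // `req' and `igrads' are restricted to the locations of the inputs
+  // of the branch that was taken in the forward pass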
+  std::vector<NDArray> ograds(inputs.begin(), inputs.begin() + params.num_outputs);
+  std::vector<OpReqType> req;
+  extract_by_loc(_req, func_input_locs, &req);
+  std::vector<NDArray> igrads;
+  extract_by_loc(outputs, func_input_locs, &igrads);
+  loop_state.Backward(0, ograds, req, igrads);
+  loop_state.Cleanup();
+}
+
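+// Infer shapes by running shape inference on each of the three subgraphs
+// (cond, then, else) separately, mapping the operator-level shapes into each
+// subgraph through the corresponding *_input_locs tuple.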
+static bool IfelseShape(const nnvm::NodeAttrs& attrs,
+                        std::vector<TShape> *in_shape,
+                        std::vector<TShape> *out_shape) {
+  using nnvm::ShapeVector;
+  const IfelseParam& params = nnvm::get<IfelseParam>(attrs.parsed);
+  static const std::function<bool(const TShape &)> is_udf = is_shape_udf;
+  // sanity checks
+  CHECK_EQ(in_shape->size() + 3U, (size_t) params.num_args);
+  CHECK_EQ(out_shape->size(), (size_t) params.num_outputs);
+  CHECK_EQ(attrs.subgraphs.size(), 3U);
+  CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U);
+  CHECK_EQ(attrs.subgraphs[1]->outputs.size(), attrs.subgraphs[2]->outputs.size());
+  // infer shape for cond, then and else
+  auto infer_subg = [&params, in_shape, out_shape](std::shared_ptr<Symbol> subg,
+                                                   ShapeVector *_subg_out,
+                                                   const nnvm::Tuple<dim_t> &input_locs,
+                                                   bool fill_out_shape) {
+    // create subg_in
+    ShapeVector subg_in;
+    ShapeVector &subg_out = *_subg_out;
+    extract_by_loc(*in_shape, input_locs, &subg_in);
+    // create an indexed graph
+    nnvm::Graph g;
+    g.outputs = subg->outputs;
+    const auto& idx = g.indexed_graph();
+    // get input nodes
+    const auto &input_nids = idx.input_nodes();
+    // sanity checks (input_nids comes from idx.input_nodes(), and
+    // idx.outputs() mirrors g.outputs, so one check per side suffices)
+    CHECK_EQ(input_nids.size(), subg_in.size());
+    CHECK_EQ(g.outputs.size(), subg_out.size());
+    // create empty shapes for inference
+    ShapeVector shapes(idx.num_node_entries());
+    // copy subg_in into shapes
+    for (size_t i = 0; i < subg_in.size(); ++i) {
+      auto eid = idx.entry_id(input_nids[i], 0);
+      shapes[eid] = subg_in[i];
+    }
+    // copy subg_out into shapes
+    for (size_t i = 0; i < subg_out.size(); ++i) {
+      auto eid = idx.entry_id(g.outputs[i]);
+      shapes[eid] = subg_out[i];
+    }
+    // copy done, call InferShape
+    g.attrs["shape"] = std::make_shared<dmlc::any>(std::move(shapes));
+    g = exec::InferShape(std::move(g));
+    // now `shapes' won't be used anymore, use new_shapes instead
+    const auto& new_shapes = g.GetAttr<ShapeVector>("shape");
+    // copy subg_in back to in_shape
+    for (size_t i = 0; i < subg_in.size(); ++i) {
+      auto eid = idx.entry_id(input_nids[i], 0);
+      auto g_out_shape = new_shapes[eid];
+      if (g_out_shape.ndim() == 0 || g_out_shape.Size() == 0) {
+        // when the shape is not fully inferred
+        continue;
+      }
+      SHAPE_ASSIGN_CHECK(*in_shape, input_locs[i], g_out_shape);
+    }
+    if (!fill_out_shape) {
+      return true;
+    }
+    // copy subg_out back to out_shape
+    for (size_t i = 0; i < g.outputs.size(); ++i) {
+      auto eid = idx.entry_id(g.outputs[i]);
+      auto g_out_shape = new_shapes[eid];
+      if (g_out_shape.ndim() == 0 || g_out_shape.Size() == 0) {
+        // when the shape is not fully inferred
+        continue;
+      }
+      SHAPE_ASSIGN_CHECK(*out_shape, i, g_out_shape);
+    }
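+    // succeed only if shape inference left no unknown shapes in the subgraph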
+    return g.GetAttr<size_t>("shape_num_unknown_nodes") == 0;
+  };
+  ShapeVector cond_out_shape{TShape(1U)};  // this means: [(1, )]
+  ShapeVector then_out_shape(params.num_outputs);
+  ShapeVector else_out_shape(params.num_outputs);
+  bool succ_0 = infer_subg(attrs.subgraphs[0], &cond_out_shape,
+                           params.cond_input_locs, false);
+  bool succ_1 = infer_subg(attrs.subgraphs[1], &then_out_shape,
+                           params.then_input_locs, true);
+  bool succ_2 = infer_subg(attrs.subgraphs[2], &else_out_shape,
+                           params.else_input_locs, true);
+  return succ_0 && succ_1 && succ_2;
 
 Review comment:
   My bad, fixed

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services
