zheng-da commented on a change in pull request #11760: [MXNET-684] Add `ifelse` 
operator
URL: https://github.com/apache/incubator-mxnet/pull/11760#discussion_r203925292
 
 

 ##########
 File path: src/operator/control_flow.cc
 ##########
 @@ -977,6 +913,342 @@ WhileLoopGradient(const nnvm::NodePtr& n, const 
std::vector<nnvm::NodeEntry>& og
   return entries;
 }
 
struct IfelseParam : public dmlc::Parameter<IfelseParam> {
  // Total number of input arguments, counting the three subgraph symbols
  // (cond, then, else) in addition to the data inputs; hence the
  // lower bound of 3 below.
  int num_args;
  // Number of outputs produced by the then/else subgraphs (they must
  // agree on this count).
  int num_outputs;
  // For each subgraph, the positions (indices into the op's data-input
  // list) of the inputs that subgraph consumes.
  nnvm::Tuple<dim_t> cond_input_locs;
  nnvm::Tuple<dim_t> then_input_locs;
  nnvm::Tuple<dim_t> else_input_locs;
  DMLC_DECLARE_PARAMETER(IfelseParam) {
    DMLC_DECLARE_FIELD(num_args).set_lower_bound(3)
    .describe("Number of input arguments, including cond, then and else as three symbol inputs.");
    DMLC_DECLARE_FIELD(num_outputs).set_lower_bound(1)
    .describe("The number of outputs of the subgraph.");
    DMLC_DECLARE_FIELD(cond_input_locs)
    .describe("The locations of cond's inputs in the given inputs.");
    DMLC_DECLARE_FIELD(then_input_locs)
    .describe("The locations of then's inputs in the given inputs.");
    DMLC_DECLARE_FIELD(else_input_locs)
    .describe("The locations of else's inputs in the given inputs.");
  }
};  // struct IfelseParam

DMLC_REGISTER_PARAMETER(IfelseParam);
+
// Per-instance state of the ifelse operator: the compiled cond subgraph,
// the executors for both branches, and which branch the forward pass took
// (so backward can replay the same branch).
class IfelseState {
 public:
  IfelseParam params;
  // CachedOp used to execute the `cond` subgraph.
  CachedOpPtr cond_op;
  // Forward/backward executors for the two branches.
  LoopState then_branch;
  LoopState else_branch;
  int branch_selection;  // 1 if then branch; 0 if else branch; -1 if undefined

  IfelseState(const IfelseParam &params,
              const Symbol &cond,
              const Symbol &then_sym,
              const Symbol &else_sym):
              params(params),
              cond_op(LoopState::MakeSharedOp(cond)),
              then_branch(then_sym),
              else_branch(else_sym),
              branch_selection(-1) {
  }
};
+
+static void IfelseComputeExCPU(const OpStatePtr& state_ptr,
+                               const OpContext& ctx,
+                               const std::vector<NDArray>& inputs,
+                               const std::vector<OpReqType>& req,
+                               const std::vector<NDArray>& outputs) {
+  // The argument `inputs' are loop_vars and other inputs
+  // loop_vars are stored in stored in `loop_vars_locs'
+  // The argument `outputs' are output and new_loop_vars
+  // [0: num_out_data) are outputs at each step.
+  // [num_out_data: ) are new_loop_vars
+  IfelseState &state = state_ptr.get_state<IfelseState>();
+  const IfelseParam& params = state.params;
+  // a helper function, converting std::vector<NDArray> to 
std::vector<NDArray*>
+  const auto to_ptr_vec = [](std::vector<NDArray> &in, std::vector<NDArray*> 
*out) {
+    out->clear();
+    out->reserve(in.size());
+    std::transform(std::begin(in),
+                   std::end(in),
+                   std::back_inserter(*out),
+                   [](NDArray &a) {return &a;});
+  };
+  // sanity checks
+  CHECK_EQ(inputs.size() + 3U, (size_t) params.num_args);
+  CHECK_EQ(outputs.size(), (size_t) params.num_outputs);
+  CHECK_EQ(outputs.size(), req.size());
+  // construct inputs and outputs for cond
+  std::vector<NDArray> cond_inputs;
+  std::vector<NDArray> cond_outputs = {NDArray()};
+  std::vector<NDArray*> cond_input_ptr;
+  std::vector<NDArray*> cond_output_ptr;
+  extract_by_loc(inputs, params.cond_input_locs, &cond_inputs);
+  to_ptr_vec(cond_inputs, &cond_input_ptr);
+  to_ptr_vec(cond_outputs, &cond_output_ptr);
+  int &branch_selection = state.branch_selection;
+  // run cond
+  state.cond_op->Forward(nullptr, cond_input_ptr, cond_output_ptr);
+  branch_selection = as_bool_scalar(*cond_output_ptr[0]);
+  // select the right branch
+  const nnvm::Tuple<dim_t> &func_input_locs = branch_selection
+                                            ? params.then_input_locs
+                                            : params.else_input_locs;
+  LoopState &loop_state = branch_selection
+                        ? state.then_branch
+                        : state.else_branch;
+  // extract inputs for the branch
+  std::vector<NDArray> func_inputs;
+  extract_by_loc(inputs, func_input_locs, &func_inputs);
+  loop_state.Forward(0, func_inputs, req, outputs, ctx.need_grad);
+}
+
// Backward pass of the ifelse operator: replay only the branch that was
// taken during forward (recorded in state.branch_selection) and scatter
// that branch's input gradients back to the corresponding op inputs.
static void IfelseGradComputeExCPU(const OpStatePtr& state_ptr,
                                   const OpContext& ctx,
                                   const std::vector<NDArray>& inputs,
                                   const std::vector<OpReqType>& _req,
                                   const std::vector<NDArray>& outputs) {
  IfelseState &state = state_ptr.get_state<IfelseState>();
  const IfelseParam& params = state.params;
  // sanity checks: here `outputs` are the gradients of the op's data
  // inputs (the three subgraph symbol inputs are excluded, hence + 3U)
  CHECK_EQ(outputs.size() + 3U, (size_t) params.num_args);
  CHECK_EQ(outputs.size(), _req.size());
  // select the right branch; -1 means forward was never executed
  int branch_selection = state.branch_selection;
  CHECK_NE(branch_selection, -1);
  const nnvm::Tuple<dim_t> &func_input_locs = branch_selection
                                            ? params.then_input_locs
                                            : params.else_input_locs;
  LoopState &loop_state = branch_selection
                        ? state.then_branch
                        : state.else_branch;
  // construct parameters
  // the first num_outputs entries of `inputs` are the head gradients
  std::vector<NDArray> ograds(inputs.begin(), inputs.begin() + params.num_outputs);
  // restrict req and igrads to the locations consumed by the selected
  // branch; gradients for inputs used only by the other branch are not
  // written here (presumably handled by the executor — TODO confirm)
  std::vector<OpReqType> req;
  extract_by_loc(_req, func_input_locs, &req);
  std::vector<NDArray> igrads;
  extract_by_loc(outputs, func_input_locs, &igrads);
  loop_state.Backward(0, ograds, req, igrads);
  loop_state.Cleanup();
}
+
+static bool IfelseShape(const nnvm::NodeAttrs& attrs,
+                        std::vector<TShape> *in_shape,
+                        std::vector<TShape> *out_shape) {
+  using nnvm::ShapeVector;
+  const IfelseParam& params = nnvm::get<IfelseParam>(attrs.parsed);
+  static const std::function<bool(const TShape &)> is_udf = is_shape_udf;
+  // sanity checks
+  CHECK_EQ(in_shape->size() + 3U, (size_t) params.num_args);
+  CHECK_EQ(out_shape->size(), (size_t) params.num_outputs);
+  CHECK_EQ(attrs.subgraphs.size(), 3U);
+  CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U);
+  CHECK_EQ(attrs.subgraphs[1]->outputs.size(), 
attrs.subgraphs[2]->outputs.size());
+  // infer shape for cond, then and else
+  auto infer_subg = [&params, in_shape, out_shape](std::shared_ptr<Symbol> 
subg,
+                                                   ShapeVector *_subg_out,
+                                                   const nnvm::Tuple<dim_t> 
&input_locs,
+                                                   bool fill_out_shape) {
 
 Review comment:
   can you also reuse this function?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to