eric-haibin-lin commented on a change in pull request #7720: [sparse] add ftrl 
optimizer for sparse
URL: https://github.com/apache/incubator-mxnet/pull/7720#discussion_r137708966
 
 

 ##########
 File path: src/operator/optimizer_op-inl.h
 ##########
 @@ -1035,6 +1035,213 @@ inline void RMSPropUpdate(const nnvm::NodeAttrs 
&attrs, const OpContext &ctx,
   });
 }
 
/*!
 * \brief Hyper-parameters for the FTRL (Follow-The-Regularized-Leader) optimizer.
 *
 * NOTE(review): the field is spelled "lamda1" (sic) — the spelling is part of
 * the public parameter name, so it cannot be corrected without breaking callers.
 */
struct FtrlParam : public dmlc::Parameter<FtrlParam> {
  float lr;             // learning rate (required, no default)
  float lamda1;         // L1 regularization strength
  float beta;           // per-coordinate learning-rate beta
  float wd;             // L2 weight decay
  float rescale_grad;   // multiplier applied to the gradient before the update
  float clip_gradient;  // clip threshold; <= 0 disables clipping
  DMLC_DECLARE_PARAMETER(FtrlParam) {
    DMLC_DECLARE_FIELD(lr)
    .describe("Learning rate");
    DMLC_DECLARE_FIELD(lamda1)
    .set_default(0.01f)
    .describe("The L1 regularization coefficient.");
    DMLC_DECLARE_FIELD(beta)
    .set_default(1.0f)
    .describe("Per-Coordinate Learning Rate beta.");
    DMLC_DECLARE_FIELD(wd)
    .set_default(0.0f)
    .describe("Weight decay augments the objective function with a "
              "regularization term that penalizes large weights. "
              "The penalty scales with the square of the magnitude of each weight.");
    DMLC_DECLARE_FIELD(rescale_grad)
    .set_default(1.0f)
    .describe("Rescale gradient to grad = rescale_grad*grad.");
    DMLC_DECLARE_FIELD(clip_gradient)
    .set_default(-1.0f)
    .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
              "If clip_gradient <= 0, gradient clipping is turned off. "
              "grad = max(min(grad, clip_gradient), -clip_gradient).");
  }
};
+
/*!
 * \brief Dense FTRL update.
 *
 * inputs:  [0] weight, [1] grad, [2] z (FTRL dual accumulator),
 *          [3] n (accumulated squared gradient).
 * outputs: [0] updated weight.
 *
 * z and n are updated in place; grad is also mutated in place by the
 * rescale_grad multiplication below.
 */
template<typename xpu>
inline void FtrlUpdate(const nnvm::NodeAttrs& attrs,
                       const OpContext &ctx,
                       const std::vector<TBlob> &inputs,
                       const std::vector<OpReqType> &req,
                       const std::vector<TBlob> &outputs) {
  using namespace mshadow;
  using namespace mshadow::expr;
  using namespace mshadow_op;
  const FtrlParam& param = nnvm::get<FtrlParam>(attrs.parsed);
  Stream<xpu>* s = ctx.get_stream<xpu>();
  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
    Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> z = inputs[2].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> n = inputs[3].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);

    // Rescale the gradient in place before the update.
    grad = scalar<DType>(param.rescale_grad) * grad;

    // z += g - (sqrt(n + g^2) - sqrt(n)) * w / lr, then n += g^2,
    // where g is the (optionally clipped) rescaled gradient.
    if (param.clip_gradient >= 0.0f) {
      z += F<clip>(grad, DType(param.clip_gradient)) - (F<square_root>(n +
           F<square>(F<clip>(grad, DType(param.clip_gradient)))) - F<square_root>(n)) *
           weight / scalar<DType>(param.lr);
      n += F<square>(F<clip>(grad, DType(param.clip_gradient)));
    } else {
      z += grad - (F<square_root>(n + F<square>(grad)) - F<square_root>(n)) *
           weight / scalar<DType>(param.lr);
      n += F<square>(grad);
    }
    // Closed-form FTRL-Proximal weight: zero when |z| <= lamda1 (the gt factor),
    // otherwise (sign(z)*lamda1 - z) / ((beta + sqrt(n)) / lr + wd).
    Assign(out, req[0],
           (F<sign>(z) * scalar<DType>(param.lamda1) - z) /
           ((scalar<DType>(param.beta) + F<square_root>(n)) /
           scalar<DType>(param.lr) + scalar<DType>(param.wd)) *
           F<gt>(F<abs>(z), DType(param.lamda1)));
  });
}
+
+template<int req>
+struct FtrlDnsRspDnsKernel {
+  template<typename DType, typename IType>
+  MSHADOW_XINLINE static void Map(int i, const nnvm::dim_t row_length, DType* 
out_data,
+    DType* z_data, DType* n_data, const DType* weight_data, const IType* 
grad_idx,
+    const DType* grad_data, const DType clip_gradient, const DType lamda1, 
const DType beta,
+    const DType lr, const DType wd, const DType rescale_grad) {
+    using nnvm::dim_t;
+    using namespace mshadow_op;
+    const dim_t row_offset = grad_idx[i] * row_length;
+    for (dim_t j = 0; j < row_length; j++) {
+      // index in data/z/n
+      const dim_t data_i = row_offset + j;
+      // index in grad
+      const dim_t grad_i = i * row_length + j;
+      const DType grad_rescaled = grad_data[grad_i] * rescale_grad;
+      if (clip_gradient >= 0.0f) {
+        z_data[data_i] += clip::Map(grad_rescaled, clip_gradient) -
+                          (square_root::Map(n_data[data_i] +
+                          square::Map(clip::Map(grad_rescaled, 
clip_gradient))) -
+                          square_root::Map(n_data[data_i])) * 
weight_data[data_i] / lr;
+        n_data[data_i] += square::Map(clip::Map(grad_rescaled, clip_gradient));
+      } else {
+        z_data[data_i] += grad_rescaled - (square_root::Map(n_data[data_i] +
+                          square::Map(grad_rescaled)) - 
square_root::Map(n_data[data_i])) *
+                          weight_data[data_i] / lr;
+        n_data[data_i] += square::Map(grad_rescaled);
+      }
+      KERNEL_ASSIGN(out_data[data_i], req,
+                    (sign::Map(z_data[data_i]) * lamda1 - z_data[data_i]) /
+                    ((beta + square_root::Map(n_data[data_i])) / lr + wd) *
+                    gt::Map(abs::Map(z_data[data_i]), lamda1));
+    }
+  }
+};
+
+
/*!
 * \brief FTRL update for dense weight + row_sparse gradient + dense z/n,
 *        producing a dense output. Dispatches on real type, index type and
 *        write request, then launches FtrlDnsRspDnsKernel over the non-zero
 *        gradient rows. No-op when the gradient holds no non-zero rows or
 *        req is kNullOp; only kWriteInplace is supported otherwise.
 */
template<typename xpu>
inline void FtrlUpdateDnsRspDnsImpl(const FtrlParam& param,
                                    const OpContext& ctx,
                                    const TBlob& weight,
                                    const NDArray& grad,
                                    const TBlob& z,
                                    const TBlob& n,
                                    const OpReqType& req,
                                    TBlob *out) {
  using namespace mxnet_op;
  using namespace rowsparse;
  Stream<xpu>* s = ctx.get_stream<xpu>();
  // Nothing to do for an all-zero gradient or a null request.
  if (!grad.storage_initialized() || req == kNullOp) return;
  CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse ftrl_update";
  CHECK_GT(weight.shape_.Size(), 0);
  CHECK_GT(z.shape_.Size(), 0);
  CHECK_GT(n.shape_.Size(), 0);

  MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
    MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(kIdx), IType, {
      MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
        const DType* weight_data = weight.dptr<DType>();
        const IType* grad_idx = grad.aux_data(kIdx).dptr<IType>();
        const DType* grad_val = grad.data().dptr<DType>();
        DType* z_data = z.dptr<DType>();
        DType* n_data = n.dptr<DType>();
        DType* out_data = out->dptr<DType>();
        // One kernel invocation per non-zero gradient row.
        nnvm::dim_t num_rows = grad.aux_shape(kIdx)[0];
        // Elements per row = product of all trailing weight dimensions.
        const auto row_length = weight.shape_.ProdShape(1, weight.ndim());
        Kernel<FtrlDnsRspDnsKernel<req_type>, xpu>::Launch(s, num_rows, row_length,
          out_data, z_data, n_data, weight_data, grad_idx, grad_val,
          static_cast<DType>(param.clip_gradient), static_cast<DType>(param.lamda1),
          static_cast<DType>(param.beta), static_cast<DType>(param.lr),
          static_cast<DType>(param.wd), static_cast<DType>(param.rescale_grad));
      });
    });
  });
}
+
+template<typename xpu>
+inline void FtrlUpdateRspRspRspImpl(const FtrlParam& param,
+                                    const OpContext& ctx,
+                                    const NDArray& weight,
+                                    const NDArray& grad,
+                                    const NDArray& z,
+                                    const NDArray& n,
+                                    const OpReqType& req,
+                                    NDArray *out) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  using namespace mxnet_op;
+  using namespace rowsparse;
+  CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "FtrlUpdate", "weights");
+  Stream<xpu>* s = ctx.get_stream<xpu>();
+  // fill z and n with zero values in order to reuse the ftrl dns impl
 
 Review comment:
   Please update this comment for FTRL — it still refers to the SGD momentum state ("mean and variance") instead of FTRL's z and n accumulators.
 
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to