This is an automated email from the ASF dual-hosted git repository.
haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new df9f79a standard update for sparse sgd_mom_update (#9189)
df9f79a is described below
commit df9f79ae5e265e28ceecab3c58828f3a84769eb4
Author: Ziyue Huang <[email protected]>
AuthorDate: Fri Jan 5 13:36:15 2018 +0800
standard update for sparse sgd_mom_update (#9189)
* standard sparse sgd mom update
* update
* update comments
* address comments
* revise
* more general infer stype
* fix
* fix
* add comments for stype inference func
* update
---
python/mxnet/optimizer.py | 25 ++++---
src/operator/optimizer_op-inl.h | 112 ++++++++++++++++++++++++++++++--
src/operator/optimizer_op.cc | 62 +++++++++++++++++-
src/operator/optimizer_op.cu | 66 +++++++++++++++++++
tests/python/unittest/test_optimizer.py | 24 ++++++-
5 files changed, 272 insertions(+), 17 deletions(-)
diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py
index 59898c9..feff87e 100644
--- a/python/mxnet/optimizer.py
+++ b/python/mxnet/optimizer.py
@@ -433,14 +433,8 @@ register = Optimizer.register # pylint: disable=invalid-name
class SGD(Optimizer):
"""The SGD optimizer with momentum and weight decay.
- The optimizer updates the weight by::
-
- rescaled_grad = lr * rescale_grad * clip(grad, clip_gradient) + wd *
weight
- state = momentum * state + rescaled_grad
- weight = weight - state
-
- If the storage types of weight, state and grad are all ``row_sparse``, \
- **sparse updates** are applied by::
+ If the storage types of weight and grad are both ``row_sparse``, and
``lazy_update`` is True, \
+ **lazy updates** are applied by::
for row in grad.indices:
rescaled_grad[row] = lr * rescale_grad * clip(grad[row],
clip_gradient) + wd * weight[row]
@@ -454,6 +448,12 @@ class SGD(Optimizer):
provides slightly different semantics than the original update, and
may lead to different empirical results.
+ Otherwise, **standard updates** are applied by::
+
+ rescaled_grad = lr * rescale_grad * clip(grad, clip_gradient) + wd *
weight
+ state = momentum * state + rescaled_grad
+ weight = weight - state
+
For details of the update algorithm see
:class:`~mxnet.ndarray.sgd_update` and
:class:`~mxnet.ndarray.sgd_mom_update`.
@@ -464,6 +464,9 @@ class SGD(Optimizer):
----------
momentum : float, optional
The momentum value.
+ lazy_update : bool, optional
+ Default is True. If True, lazy updates are applied \
+ if the storage types of weight and grad are both ``row_sparse``.
multi_precision: bool, optional
Flag to control the internal precision of the optimizer.
``False`` results in using the same precision as the weights (default),
@@ -471,9 +474,10 @@ class SGD(Optimizer):
in 32-bit precision even if actual weights used in the model
have lower precision.\
Turning this on can improve convergence and accuracy when
training with float16.
"""
- def __init__(self, momentum=0.0, **kwargs):
+ def __init__(self, momentum=0.0, lazy_update=True, **kwargs):
super(SGD, self).__init__(**kwargs)
self.momentum = momentum
+ # Read by create_state: when False, the momentum state is created with
+ # 'default' storage regardless of the weight's storage type.
+ self.lazy_update = lazy_update
def create_state_multi_precision(self, index, weight):
weight_master_copy = None
@@ -489,8 +493,9 @@ class SGD(Optimizer):
def create_state(self, index, weight):
momentum = None
+ # With lazy_update disabled the momentum is kept dense, so a row_sparse
+ # weight/grad pair is routed to the standard (non-lazy) momentum update
+ # instead of the lazy all-row_sparse one. No state is created when
+ # momentum is 0.
+ stype = weight.stype if self.lazy_update else 'default'
if self.momentum != 0.0:
- momentum = zeros(weight.shape, weight.context, dtype=weight.dtype,
stype=weight.stype)
+ momentum = zeros(weight.shape, weight.context, dtype=weight.dtype,
stype=stype)
return momentum
def _update_impl(self, index, weight, grad, state, multi_precision=False):
diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h
index a6b32b1..33b7dd5 100644
--- a/src/operator/optimizer_op-inl.h
+++ b/src/operator/optimizer_op-inl.h
@@ -38,6 +38,7 @@
#include "./elemwise_op_common.h"
#include "mxnet_op.h"
#include "./tensor/init_op.h"
+#include "./tensor/util/tensor_util-inl.h"
namespace mxnet {
namespace op {
@@ -460,6 +461,106 @@ inline void SGDMomUpdateRspRspRspImpl(const SGDMomParam& param,
mom.data(), req, &out_blob);
}
+/*!
+ * \brief Storage type inference function for optimizer operators.
+ *
+ * The inputs are split into two groups: the first n_rsp inputs and the
+ * remaining n_rsp_dns inputs.
+ *  - If every input is default storage, the dense kFCompute path is used and
+ *    the output is default storage.
+ *  - Otherwise, if the first group is all row_sparse and the second group is
+ *    either all row_sparse or all default (or empty), kFComputeEx is
+ *    dispatched and the output is row_sparse.
+ *  - Any other combination falls back to dense computation.
+ *
+ * \tparam n_rsp The number of inputs that should be of row_sparse storage type
+ *               if kFComputeEx is dispatched
+ * \tparam n_rsp_dns The number of inputs that should be of row_sparse or default
+ *                   storage type if kFComputeEx is dispatched
+ * \return always true; the selected mode is written to dispatch_mode
+ */
+template<int n_rsp, int n_rsp_dns>
+inline bool StdOptStorageType(const nnvm::NodeAttrs& attrs,
+                              const int dev_mask,
+                              DispatchMode* dispatch_mode,
+                              std::vector<int>* in_attrs,
+                              std::vector<int>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), static_cast<size_t>(n_rsp + n_rsp_dns));
+  CHECK_EQ(out_attrs->size(), 1U);
+  bool dispatched = false;
+
+  if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
+    // dns, ... -> dns
+    dispatched = storage_type_assign(out_attrs, kDefaultStorage,
+                                     dispatch_mode, DispatchMode::kFCompute);
+  }
+
+  if (!dispatched) {
+    const std::vector<int> rsp_stypes(in_attrs->begin(), in_attrs->begin() + n_rsp);
+    const std::vector<int> rsp_dns_stypes(in_attrs->begin() + n_rsp, in_attrs->end());
+    // common::ContainsOnlyStorage returns false for empty inputs, so an empty
+    // rsp/dns group (n_rsp_dns == 0) has to be accepted explicitly.
+    const bool rsp_dns_ok = rsp_dns_stypes.empty() ||
+        common::ContainsOnlyStorage(rsp_dns_stypes, kRowSparseStorage) ||
+        common::ContainsOnlyStorage(rsp_dns_stypes, kDefaultStorage);
+    if (common::ContainsOnlyStorage(rsp_stypes, kRowSparseStorage) && rsp_dns_ok) {
+      // rsp, ..., rsp/dns, ... -> rsp
+      dispatched = storage_type_assign(out_attrs, kRowSparseStorage,
+                                       dispatch_mode, DispatchMode::kFComputeEx);
+    }
+  }
+
+  if (!dispatched) {
+    dispatch_fallback(out_attrs, dispatch_mode);
+    LogStorageFallback(attrs, dev_mask, in_attrs, out_attrs);
+  }
+  return true;
+}
+
+/*!
+ * \brief Per-row kernel for the standard (non-lazy) momentum SGD update with a
+ *        dense weight, a row_sparse gradient and a dense momentum.
+ *
+ * Invocation i processes weight row i. prefix_sum must hold the inclusive
+ * prefix sum of per-row gradient flags:
+ *  - Row i has a gradient row iff prefix_sum[i] exceeds the previous entry
+ *    (or is positive for i == 0).
+ *  - That gradient row is stored at position prefix_sum[i] - 1 of grad_data.
+ *  - Rows without a gradient use a zero gradient, so weight decay and momentum
+ *    are still applied to them.
+ *
+ * grad_idx is accepted but not read here.
+ */
+template<int req>
+struct SGDMomStdDnsRspDnsKernel {
+  template<typename DType, typename IType, typename RType>
+  MSHADOW_XINLINE static void Map(int i, index_t row_length, DType* out_data,
+    DType* mom_data, const DType* weight_data, const IType* grad_idx,
+    const DType* grad_data, const RType* prefix_sum, const DType clip_gradient,
+    const DType momentum, const DType lr, const DType wd, const DType rescale_grad) {
+    const DType decay = lr * wd;
+    const RType rows_before = (i == 0) ? static_cast<RType>(0) : prefix_sum[i - 1];
+    const bool has_grad = prefix_sum[i] > rows_before;
+
+    const index_t row_begin = i * row_length;
+    const RType grad_begin = (prefix_sum[i] - 1) * row_length;
+    for (index_t col = 0; col < row_length; ++col) {
+      const index_t k = row_begin + col;
+      const DType g = has_grad ? grad_data[grad_begin + col] : static_cast<DType>(0);
+      if (clip_gradient >= 0.0f) {
+        mom_data[k] = momentum * mom_data[k]
+                      - decay * weight_data[k]
+                      - lr * mshadow_op::clip::Map(rescale_grad * g, clip_gradient);
+      } else {
+        mom_data[k] = momentum * mom_data[k]
+                      - decay * weight_data[k]
+                      - lr * rescale_grad * g;
+      }
+      KERNEL_ASSIGN(out_data[k], req, weight_data[k] + mom_data[k]);
+    }
+  }
+};
+
+/*!
+ * \brief Standard (non-lazy) momentum SGD update for a dense weight, a
+ *        row_sparse gradient and a dense momentum, writing into the dense
+ *        TBlob `out`.
+ *
+ * Only declared here. The cpu and gpu specializations are defined in
+ * optimizer_op.cc and optimizer_op.cu, and both expect req to be
+ * kWriteInplace (or kNullOp).
+ */
+template<typename xpu>
+void SGDMomStdUpdateDnsRspDnsImpl(const SGDMomParam& param,
+                                  const OpContext& ctx,
+                                  const TBlob& weight,
+                                  const NDArray& grad,
+                                  const TBlob& mom,
+                                  const OpReqType& req,
+                                  TBlob *out);
+
+/*!
+ * \brief Standard (non-lazy) momentum SGD update for a row_sparse weight, a
+ *        row_sparse gradient and a dense momentum.
+ *
+ * The weight is required to have every row present (checked below), so its
+ * data blob can be treated as a dense matrix. The actual work is delegated to
+ * SGDMomStdUpdateDnsRspDnsImpl, which also acquires the device stream.
+ */
+template<typename xpu>
+inline void SGDMomStdUpdateRspRspDnsImpl(const SGDMomParam& param,
+                                         const OpContext& ctx,
+                                         const NDArray& weight,
+                                         const NDArray& grad,
+                                         const NDArray& mom,
+                                         const OpReqType& req,
+                                         NDArray *out) {
+  CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "SGDMomUpdate", "weights");
+  // out shares the dense layout of weight; hand its data blob to the dense impl.
+  TBlob out_blob = out->data();
+  SGDMomStdUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad,
+                                    mom.data(), req, &out_blob);
+}
+
template<typename xpu>
inline void SGDMomUpdateEx(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
@@ -474,12 +575,15 @@ inline void SGDMomUpdateEx(const nnvm::NodeAttrs& attrs,
const auto weight_stype = weight.storage_type();
const auto mom_stype = mom.storage_type();
const auto out_stype = outputs[0].storage_type();
- CHECK_EQ(weight_stype, mom_stype) << "Inconsistent storage type detected
between mom.stype = "
- << mom_stype << " and weight.stype = " << weight_stype;
+ NDArray out = outputs[0];
if (common::ContainsOnlyStorage(inputs, kRowSparseStorage) &&
out_stype == kRowSparseStorage) {
- NDArray out = outputs[0];
- SGDMomUpdateRspRspRspImpl<xpu>(param, ctx, weight, grad, mom, req[0],
&out);
+ SGDMomUpdateRspRspRspImpl<xpu>(param, ctx, weight, grad, mom, req[0],
&out);
+ } else if (weight.storage_type() == kRowSparseStorage &&
+ grad.storage_type() == kRowSparseStorage &&
+ mom.storage_type() == kDefaultStorage &&
+ out_stype == kRowSparseStorage) {
+ SGDMomStdUpdateRspRspDnsImpl<xpu>(param, ctx, weight, grad, mom, req[0],
&out);
} else {
LOG(FATAL) << "Not implemented: " << operator_string(attrs, ctx, inputs,
req, outputs);
}
diff --git a/src/operator/optimizer_op.cc b/src/operator/optimizer_op.cc
index 4de94e5..dda8092 100644
--- a/src/operator/optimizer_op.cc
+++ b/src/operator/optimizer_op.cc
@@ -37,6 +37,57 @@ DMLC_REGISTER_PARAMETER(RMSPropParam);
DMLC_REGISTER_PARAMETER(RMSPropAlexParam);
DMLC_REGISTER_PARAMETER(FtrlParam);
+/*!
+ * \brief CPU specialization of the standard momentum SGD update with a dense
+ *        weight, row_sparse gradient and dense momentum.
+ *
+ * Builds an inclusive prefix sum of "row has gradient" flags in temp space,
+ * then runs SGDMomStdDnsRspDnsKernel once per weight row.
+ */
+template<>
+void SGDMomStdUpdateDnsRspDnsImpl<cpu>(const SGDMomParam& param,
+                                       const OpContext& ctx,
+                                       const TBlob& weight,
+                                       const NDArray& grad,
+                                       const TBlob& mom,
+                                       const OpReqType& req,
+                                       TBlob *out) {
+  using namespace mxnet_op;
+  using namespace rowsparse;
+  using namespace mshadow;
+  Stream<cpu>* stream = ctx.get_stream<cpu>();
+  if (req == kNullOp) return;
+  CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse sgd_mom_update";
+  CHECK_GT(weight.shape_.Size(), 0);
+  CHECK_GT(mom.shape_.Size(), 0);
+  MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
+    MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(kIdx), IType, {
+      MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
+        const nnvm::dim_t n_rows = weight.shape_[0];
+        const auto row_len = weight.shape_.ProdShape(1, weight.ndim());
+        IType* grad_rows = grad.aux_data(kIdx).dptr<IType>();
+        DType* grad_vals = grad.data().dptr<DType>();
+
+        // Temp space holds one nnvm::dim_t counter per weight row.
+        Tensor<cpu, 1, char> scratch = ctx.requested[0]
+            .get_space_typed<cpu, 1, char>(Shape1(n_rows * sizeof(nnvm::dim_t)), stream);
+        nnvm::dim_t* row_counts = reinterpret_cast<nnvm::dim_t*>(scratch.dptr_);
+        Kernel<set_zero, cpu>::Launch(stream, n_rows, row_counts);
+        if (grad.storage_initialized()) {
+          // Flag every row that appears in the gradient ...
+          Kernel<MarkRowFlgKernel, cpu>::Launch(stream, grad.aux_shape(kIdx)[0],
+                                                row_counts, grad_rows);
+          // ... then accumulate the flags into an inclusive prefix sum.
+          for (nnvm::dim_t r = 1; r < n_rows; ++r) {
+            row_counts[r] += row_counts[r - 1];
+          }
+        }
+
+        Kernel<SGDMomStdDnsRspDnsKernel<req_type>, cpu>::Launch(stream, n_rows, row_len,
+            out->dptr<DType>(), mom.dptr<DType>(), weight.dptr<DType>(),
+            grad_rows, grad_vals, row_counts,
+            static_cast<DType>(param.clip_gradient), static_cast<DType>(param.momentum),
+            static_cast<DType>(param.lr), static_cast<DType>(param.wd),
+            static_cast<DType>(param.rescale_grad));
+      });
+    });
+  });
+}
+
+
NNVM_REGISTER_OP(sgd_update)
MXNET_ADD_SPARSE_OP_ALIAS(sgd_update)
.describe(R"code(Update function for Stochastic Gradient Descent (SDG)
optimizer.
@@ -84,7 +135,10 @@ It updates the weights using::
Where the parameter ``momentum`` is the decay rate of momentum estimates at
each epoch.
-If weight and momentum are both of ``row_sparse`` storage type,
+If weight and grad are both of ``row_sparse`` storage type and momentum is of
``default`` storage type,
+standard update is applied.
+
+If weight, grad and momentum are all of ``row_sparse`` storage type,
only the row slices whose indices appear in grad.indices are updated (for both
weight and momentum)::
for row in gradient.indices:
@@ -97,11 +151,15 @@ only the row slices whose indices appear in grad.indices are updated (for both w
.set_attr_parser(ParamParser<SGDMomParam>)
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
-.set_attr<FInferStorageType>("FInferStorageType", ElemwiseStorageType<3, 1,
false, true, false>)
+.set_attr<FInferStorageType>("FInferStorageType", StdOptStorageType<2, 1>)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
[](const nnvm::NodeAttrs& attrs) {
return std::vector<uint32_t>{2};
})
+.set_attr<FResourceRequest>("FResourceRequest",
+ [](const NodeAttrs& attrs) {
+ return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+ })
.set_attr<FCompute>("FCompute<cpu>", SGDMomUpdate<cpu>)
.set_attr<FComputeEx>("FComputeEx<cpu>", SGDMomUpdateEx<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
diff --git a/src/operator/optimizer_op.cu b/src/operator/optimizer_op.cu
index 4306b32..9512e92 100644
--- a/src/operator/optimizer_op.cu
+++ b/src/operator/optimizer_op.cu
@@ -24,10 +24,76 @@
* \author Junyuan Xie
*/
#include "./optimizer_op-inl.h"
+#include <cub/cub.cuh>
namespace mxnet {
namespace op {
+/*!
+ * \brief GPU specialization of the standard momentum SGD update with a dense
+ *        weight, row_sparse gradient and dense momentum.
+ *
+ * Row flags are marked in temp space and turned into an inclusive prefix sum
+ * with cub::DeviceScan. SGDMomStdDnsRspDnsKernel then runs once per weight row.
+ */
+template<>
+void SGDMomStdUpdateDnsRspDnsImpl<gpu>(const SGDMomParam& param,
+                                       const OpContext& ctx,
+                                       const TBlob& weight,
+                                       const NDArray& grad,
+                                       const TBlob& mom,
+                                       const OpReqType& req,
+                                       TBlob *out) {
+  using namespace mxnet_op;
+  using namespace rowsparse;
+  using namespace mshadow;
+  Stream<gpu>* stream = ctx.get_stream<gpu>();
+  if (req == kNullOp) return;
+  CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse sgd_mom_update";
+  CHECK_GT(weight.shape_.Size(), 0);
+  CHECK_GT(mom.shape_.Size(), 0);
+
+  MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
+    MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(kIdx), IType, {
+      MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
+        const nnvm::dim_t n_rows = weight.shape_[0];
+        const nnvm::dim_t row_len = weight.shape_.ProdShape(1, weight.ndim());
+        IType* grad_rows = grad.aux_data(kIdx).dptr<IType>();
+        DType* grad_vals = grad.data().dptr<DType>();
+        auto cuda_stream = Stream<gpu>::GetStream(stream);
+
+        // With a null temp-storage pointer cub only reports the scratch size
+        // the inclusive scan needs.
+        nnvm::dim_t* row_counts = nullptr;
+        size_t scan_bytes = 0;
+        cub::DeviceScan::InclusiveSum(nullptr, scan_bytes, row_counts, row_counts,
+                                      n_rows, cuda_stream);
+
+        // Workspace layout: [n_rows row counters | cub scan scratch].
+        const size_t counts_bytes = n_rows * sizeof(nnvm::dim_t);
+        Tensor<gpu, 1, char> workspace = ctx.requested[0]
+            .get_space_typed<gpu, 1, char>(Shape1(counts_bytes + scan_bytes), stream);
+        row_counts = reinterpret_cast<nnvm::dim_t*>(workspace.dptr_);
+        void* scan_scratch = workspace.dptr_ + counts_bytes;
+
+        Fill<false>(stream, TBlob(row_counts, Shape1(n_rows), gpu::kDevMask), kWriteTo, 0);
+        if (grad.storage_initialized()) {
+          // Flag the rows present in the gradient, then prefix-sum the flags.
+          Kernel<MarkRowFlgKernel, gpu>::Launch(stream, grad.aux_shape(kIdx)[0],
+                                                row_counts, grad_rows);
+          cub::DeviceScan::InclusiveSum(scan_scratch, scan_bytes, row_counts, row_counts,
+                                        n_rows, cuda_stream);
+        }
+
+        Kernel<SGDMomStdDnsRspDnsKernel<req_type>, gpu>::Launch(stream, n_rows, row_len,
+            out->dptr<DType>(), mom.dptr<DType>(), weight.dptr<DType>(),
+            grad_rows, grad_vals, row_counts,
+            static_cast<DType>(param.clip_gradient), static_cast<DType>(param.momentum),
+            static_cast<DType>(param.lr), static_cast<DType>(param.wd),
+            static_cast<DType>(param.rescale_grad));
+      });
+    });
+  });
+}
+
NNVM_REGISTER_OP(sgd_update)
.set_attr<FCompute>("FCompute<gpu>", SGDUpdate<gpu>)
.set_attr<FComputeEx>("FComputeEx<gpu>", SGDUpdateEx<gpu>);
diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py
index 655e157..ae248b0 100644
--- a/tests/python/unittest/test_optimizer.py
+++ b/tests/python/unittest/test_optimizer.py
@@ -334,6 +334,29 @@ def test_sparse_sgd():
w_stype='row_sparse',
g_stype='row_sparse')
+def test_std_sparse_sgd():
+ mx.random.seed(0)
+ opt1 = PySGD
+ opt2 = mx.optimizer.SGD
+ shape = (3, 4, 5)
+ mom_options = [{'momentum': 0.9}]
+ cg_options = [{}, {'clip_gradient': 0.4}, {'clip_gradient': 0.5}]
+ rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}]
+ wd_options = [{}, {'wd': 0.03}, {'wd': 0.05}, {'wd': 0.07}]
+ for dtype in [np.float32]:
+ for mom_option in mom_options:
+ for cg_option in cg_options:
+ for rg_option in rg_options:
+ for wd_option in wd_options:
+ kwarg = {}
+ kwarg.update(mom_option)
+ kwarg.update(cg_option)
+ kwarg.update(rg_option)
+ kwarg.update(wd_option)
+ compare_optimizer(opt1(**kwarg),
opt2(lazy_update=False, **kwarg), shape, dtype,
+ w_stype='row_sparse',
g_stype='row_sparse')
+
+
# FTML
class PyFTML(mx.optimizer.Optimizer):
@@ -400,7 +423,6 @@ def test_ftml():
compare_optimizer(opt1(**kwarg), opt2(**kwarg),
shape, dtype)
-
# ADAM
class PyAdam(mx.optimizer.Optimizer):
--
To stop receiving notification emails like this one, please contact
['"[email protected]" <[email protected]>'].