eric-haibin-lin closed pull request #11355: Enable support for dense weight and sparse grad Adagrad updates
URL: https://github.com/apache/incubator-mxnet/pull/11355
This is a PR merged from a forked repository.
Because GitHub does not show the original diff once a pull request from a
fork has been merged, the diff is reproduced below for the sake of provenance:
diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py
index 0c3fc904fb1..267a402f246 100644
--- a/python/mxnet/optimizer.py
+++ b/python/mxnet/optimizer.py
@@ -1107,7 +1107,7 @@ def update(self, index, weight, grad, state):
lr = self._get_lr(index)
wd = self._get_wd(index)
- is_sparse = weight.stype == 'row_sparse' and grad.stype == 'row_sparse'
+ is_sparse = grad.stype == 'row_sparse'
history = state
if is_sparse:
diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h
index 28b382c92fb..9251b861480 100644
--- a/src/operator/optimizer_op-inl.h
+++ b/src/operator/optimizer_op-inl.h
@@ -1663,16 +1663,20 @@ inline bool AdagradStorageType(const nnvm::NodeAttrs& attrs,
DispatchMode* dispatch_mode,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
+ const AdagradParam& param = nnvm::get<AdagradParam>(attrs.parsed);
CHECK_EQ(in_attrs->size(), 3U);
CHECK_EQ(out_attrs->size(), 1U);
- const AdagradParam& param = nnvm::get<AdagradParam>(attrs.parsed);
+ const int weight_stype = in_attrs->at(0);
+ const int grad_stype = in_attrs->at(1);
+ const int state_stype = in_attrs->at(2);
bool dispatched = false;
- if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kRowSparseStorage)
&&
- common::ContainsOnlyStorage(*in_attrs, kRowSparseStorage) &&
- param.wd == 0.0f) {
- // rsp, rsp, rsp -> rsp with wd = 0.0
- dispatched = storage_type_assign(out_attrs, kRowSparseStorage,
- dispatch_mode, DispatchMode::kFComputeEx);
+ if (!dispatched && grad_stype == kRowSparseStorage &&
+ (weight_stype == kRowSparseStorage || weight_stype == kDefaultStorage) &&
+ state_stype == weight_stype && param.wd == 0.0f) {
+ // weight and state share stype, grad's stype = rsp
+ dispatched = storage_type_assign(
+ out_attrs, static_cast<NDArrayStorageType>(weight_stype),
dispatch_mode,
+ DispatchMode::kFComputeEx);
}
return dispatched;
}
@@ -1802,10 +1806,24 @@ inline void AdagradUpdateEx(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray> &outputs) {
using namespace mxnet_op;
const AdagradParam& param = nnvm::get<AdagradParam>(attrs.parsed);
+
+ const auto weight_stype = inputs[0].storage_type();
+ const auto grad_stype = inputs[1].storage_type();
+ const auto state_stype = inputs[2].storage_type();
+ const auto output_stype = outputs[0].storage_type();
+
if (common::ContainsOnlyStorage(inputs, kRowSparseStorage) &&
common::ContainsOnlyStorage(outputs, kRowSparseStorage)) {
NDArray out = outputs[0];
- AdagradUpdateRspRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1],
inputs[2], req[0], &out);
+ AdagradUpdateRspRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1],
inputs[2],
+ req[0], &out);
+ } else if (state_stype == weight_stype && output_stype == weight_stype &&
+ weight_stype == kDefaultStorage &&
+ grad_stype == kRowSparseStorage) {
+ TBlob out_blob = outputs[0].data();
+ AdagradUpdateDnsRspDnsImpl<xpu>(param, ctx, inputs[0].data(), inputs[1],
+ inputs[2].data(), req[0],
+ &out_blob);
} else {
LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
}
diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py
index fba10fb522a..a5b3d4047df 100644
--- a/tests/python/unittest/test_optimizer.py
+++ b/tests/python/unittest/test_optimizer.py
@@ -1034,6 +1034,8 @@ def test_adagrad():
if wd_option.get('wd', 0.0) == 0.0:
compare_optimizer(opt1(**kwarg), opt2(**kwarg),
shape, dtype,
w_stype='row_sparse',
g_stype='row_sparse')
+ compare_optimizer(opt1(**kwarg), opt2(**kwarg),
shape, dtype,
+ g_stype='row_sparse')
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services