piiswrong closed pull request #11001: [MXNET-374] handle row_sparse weight in parameter and trainer
URL: https://github.com/apache/incubator-mxnet/pull/11001
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:
diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh
index 7abe767c869..10bca17b5ff 100755
--- a/ci/docker/runtime_functions.sh
+++ b/ci/docker/runtime_functions.sh
@@ -544,6 +544,7 @@ integrationtest_ubuntu_gpu_dist_kvstore() {
../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py
../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py --no-multiprecision
../../tools/launch.py -n 7 --launcher local python dist_device_sync_kvstore.py
+ ../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py --type=gluon
}
test_ubuntu_cpu_python2() {
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index dbe3c5e032b..4b37f4328e7 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -606,6 +606,7 @@ class HybridBlock(Block):
Refer `Hybrid tutorial <http://mxnet.io/tutorials/gluon/hybrid.html>`_ to see
the end-to-end usage.
+
"""
def __init__(self, prefix=None, params=None):
super(HybridBlock, self).__init__(prefix=prefix, params=params)
@@ -879,6 +880,14 @@ def __init__(self, outputs, inputs, params=None):
"Input symbols must be variable, but %s is an output of
operators"%str(i)
input_names.add(i.name)
+ # check if any symbol is row_sparse
+ row_sparse_storage = ndarray.ndarray._STORAGE_TYPE_STR_TO_ID['row_sparse']
+ for i in out:
+ for j in i.get_internals():
+ assert(j.attr("__storage_type__") != str(row_sparse_storage)), \
+ "SymbolBlock doesn't support Parameter '%s' because its storage " \
+ "type is 'row_sparse'." % j.name
+
for i in out.list_arguments():
if i not in input_names:
self.params.get(i, allow_deferred_init=True)
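
For illustration, the new SymbolBlock check rejects graphs that contain a 'row_sparse' variable at construction time. A minimal sketch of the guarded behavior, assuming a build with this patch applied (it mirrors test_sparse_symbol_block further down in the diff):

import mxnet as mx
from mxnet import gluon

data = mx.sym.var('data')
# 'stype' here sets the __storage_type__ attribute the new check inspects
weight = mx.sym.var('weight', stype='row_sparse')
out = mx.sym.dot(data, weight)
try:
    gluon.SymbolBlock(out, data)  # raises AssertionError
except AssertionError as e:
    print(e)  # SymbolBlock doesn't support Parameter 'weight' because its storage type is 'row_sparse'.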
diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py
index c7cbcccc95e..3265fef2b6c 100644
--- a/python/mxnet/gluon/parameter.py
+++ b/python/mxnet/gluon/parameter.py
@@ -81,6 +81,8 @@ class Parameter(object):
Weight decay multiplier (L2 regularizer coefficient). Works similar to lr_mult.
init : Initializer, default None
Initializer of this parameter. Will use the global initializer by default.
+ stype: {'default', 'row_sparse', 'csr'}, defaults to 'default'.
+ The storage type of the parameter.
grad_stype: {'default', 'row_sparse', 'csr'}, defaults to 'default'.
The storage type of the parameter's gradient.
@@ -99,12 +101,13 @@ class Parameter(object):
"""
def __init__(self, name, grad_req='write', shape=None, dtype=mx_real_t,
lr_mult=1.0, wd_mult=1.0, init=None, allow_deferred_init=False,
- differentiable=True, grad_stype='default'):
+ differentiable=True, stype='default', grad_stype='default'):
self._var = None
self._data = None
self._grad = None
self._ctx_list = None
self._ctx_map = None
+ self._trainer = None
self._deferred_init = ()
self._differentiable = differentiable
self._allow_deferred_init = allow_deferred_init
@@ -116,10 +119,14 @@ def __init__(self, name, grad_req='write', shape=None, dtype=mx_real_t,
self.wd_mult = wd_mult
self.grad_req = grad_req
self.init = init
- assert grad_stype in ['default', 'row_sparse', 'csr'], \
- "grad_stype for Parameter '%s' must be one of 'default', 'row_sparse', or 'csr'," \
- " but got '%s'" % (name, grad_stype)
+ # sparse related storage type information
+ valid_stypes = ['default', 'row_sparse', 'csr']
+ assert grad_stype in valid_stypes, "grad_stype for Parameter '%s' must be " \
+ "one of 'default', 'row_sparse', or 'csr', but got '%s'" % (name, grad_stype)
+ assert stype in valid_stypes, "stype for Parameter '%s' must be " \
+ "one of 'default', 'row_sparse', or 'csr', but got '%s'" % (name, stype)
self._grad_stype = grad_stype
+ self._stype = stype
def __repr__(self):
@@ -162,6 +169,16 @@ def shape(self, new_shape):
self._shape = new_shape
+ def _set_trainer(self, trainer):
+ """ Set the trainer this parameter is associated with. """
+ # trainer cannot be replaced for sparse params
+ if self._stype != 'default' and self._trainer and trainer and self._trainer is not trainer:
+ raise RuntimeError(
+ "Failed to set the trainer for Parameter '%s' because it was already set. " \
+ "More than one Trainer for a %s Parameter is not supported." \
+ %(self.name, self._stype))
+ self._trainer = trainer
+
def _check_and_get(self, arr_list, ctx):
if arr_list is not None:
if ctx is list:
@@ -194,6 +211,20 @@ def _check_and_get(self, arr_list, ctx):
"because the later does not include Parameters of " \
"nested child Blocks"%(self.name))
+ def _get_row_sparse(self, arr_list, ctx, row_id):
+ """ Get row_sparse data from row_sparse parameters based on row_id. """
+ # get row sparse params based on row ids
+ if not isinstance(row_id, ndarray.NDArray):
+ raise TypeError("row_id must have NDArray type, but %s is given"%(type(row_id)))
+ if not self._trainer:
+ raise RuntimeError("Cannot get row_sparse data for Parameter '%s' when no " \
+ "Trainer is created with it."%self.name)
+ results = self._check_and_get(arr_list, ctx)
+
+ # fetch row sparse params from the trainer
+ self._trainer._row_sparse_pull(self, results, row_id)
+ return results
+
def _load_init(self, data, ctx):
"""(Re)initializes by loading from data."""
if self.shape:
@@ -208,6 +239,8 @@ def _load_init(self, data, ctx):
"Failed loading Parameter '%s' from saved params: " \
"dtype incompatible expected %s vs saved %s"%(
self.name, str(self.dtype), str(data.dtype))
+ if self._stype != data.stype:
+ data = data.tostype(self._stype)
if isinstance(ctx, Context):
ctx = [ctx]
if self._data is None:
@@ -243,7 +276,7 @@ def _finish_deferred_init(self):
with autograd.pause():
if data is None:
data = ndarray.zeros(shape=self.shape, dtype=self.dtype,
- ctx=context.cpu())
+ ctx=context.cpu(), stype=self._stype)
initializer.create(default_init)(
initializer.InitDesc(self.name, {'__init__': init}), data)
@@ -271,12 +304,18 @@ def _init_grad(self):
self._grad = [ndarray.zeros(shape=i.shape, dtype=i.dtype, ctx=i.context,
stype=self._grad_stype) for i in self._data]
- autograd.mark_variables(self.list_data(), self.list_grad(), self.grad_req)
+ autograd.mark_variables(self._check_and_get(self._data, list),
+ self._grad, self.grad_req)
def _reduce(self):
"""Reduce data from multiple context."""
- block = self.list_data()
- data = ndarray.add_n(*(w.copyto(context.cpu()) for w in block)) / len(block)
+ if self._stype == 'default':
+ block = self.list_data()
+ data = ndarray.add_n(*(w.copyto(context.cpu()) for w in block)) / len(block)
+ else:
+ # fetch all rows for 'row_sparse' param
+ all_row_ids = ndarray.arange(0, self.shape[0], dtype='int64', ctx=context.cpu())
+ data = self.row_sparse_data(all_row_ids)
return data
def initialize(self, init=None, ctx=None, default_init=initializer.Uniform(),
@@ -380,12 +419,58 @@ def set_data(self, data):
self._deferred_init = self._deferred_init[:3] + (data,)
return
- for arr in self.list_data():
+ # if update_on_kvstore, we need to make sure the copy stored in kvstore is in sync
+ if self._trainer and self._trainer._kv_initialized and self._trainer._update_on_kvstore:
+ if self not in self._trainer._params_to_init:
+ self._trainer._reset_kvstore()
+
+ for arr in self._check_and_get(self._data, list):
arr[:] = data
+ def row_sparse_data(self, row_id):
+ """Returns a copy of the 'row_sparse' parameter on the same context as row_id's.
+ The copy only retains rows whose ids occur in provided row ids.
+ The parameter must have been initialized on this context before.
+
+ Parameters
+ ----------
+ row_id: NDArray
+ Row ids to retain for the 'row_sparse' parameter.
+
+ Returns
+ -------
+ NDArray on row_id's context
+ """
+ if self._stype != 'row_sparse':
+ raise RuntimeError("Cannot return a copy of Parameter %s via row_sparse_data() " \
+ "because its storage type is %s. Please use data() instead." \
+ %(self.name, self._stype))
+ return self._get_row_sparse(self._data, row_id.context, row_id)
+
+ def list_row_sparse_data(self, row_id):
+ """Returns copies of the 'row_sparse' parameter on all contexts, in the same order
+ as creation. The copy only retains rows whose ids occur in provided row ids.
+ The parameter must have been initialized before.
+
+ Parameters
+ ----------
+ row_id: NDArray
+ Row ids to retain for the 'row_sparse' parameter.
+
+ Returns
+ -------
+ list of NDArrays
+ """
+ if self._stype != 'row_sparse':
+ raise RuntimeError("Cannot return copies of Parameter '%s' on all contexts via " \
+ "list_row_sparse_data() because its storage type is %s. Please " \
+ "use data() instead." % (self.name, self._stype))
+ return self._get_row_sparse(self._data, list, row_id)
+
def data(self, ctx=None):
"""Returns a copy of this parameter on one context. Must have been
- initialized on this context before.
+ initialized on this context before. For sparse parameters, use
+ :py:meth:`Parameter.row_sparse_data` instead.
Parameters
----------
@@ -396,11 +481,25 @@ def data(self, ctx=None):
-------
NDArray on ctx
"""
+ if self._stype != 'default':
+ raise RuntimeError("Cannot return a copy of Parameter '%s' on ctx %s via data() " \
+ "because its storage type is %s. Please use row_sparse_data() " \
+ "instead." % (self.name, str(ctx), self._stype))
return self._check_and_get(self._data, ctx)
def list_data(self):
"""Returns copies of this parameter on all contexts, in the same order
- as creation."""
+ as creation. For sparse parameters, use
:py:meth:`Parameter.list_row_sparse_data`
+ instead.
+
+ Returns
+ -------
+ list of NDArrays
+ """
+ if self._stype != 'default':
+ raise RuntimeError("Cannot return copies of Parameter '%s' on all
contexts via " \
+ "list_data() because its storage type is %s.
Please use " \
+ "row_sparse_data() instead." % (self.name,
self._stype))
return self._check_and_get(self._data, list)
def grad(self, ctx=None):
@@ -447,7 +546,7 @@ def var(self):
if self._var is None:
self._var = symbol.var(self.name, shape=self.shape, dtype=self.dtype,
lr_mult=self.lr_mult, wd_mult=self.wd_mult,
- init=self.init)
+ init=self.init, stype=self._stype)
return self._var
def cast(self, dtype):
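
Taken together, the Parameter changes give the following user-facing flow for 'row_sparse' weights. A minimal sketch assuming this patch, patterned on test_sparse_parameter below: row_sparse_data() requires an associated Trainer, and data() is reserved for 'default' storage:

import mxnet as mx
from mxnet import gluon

p = gluon.Parameter('weight', shape=(10, 10), stype='row_sparse', grad_stype='row_sparse')
p.initialize(init='xavier', ctx=[mx.cpu(0), mx.cpu(1)])
trainer = gluon.Trainer([p], 'sgd')   # the Trainer registers itself via param._set_trainer

row_id = mx.nd.arange(0, 10)
w = p.row_sparse_data(row_id)         # pulls the requested rows from the kvstore copy
assert w.stype == 'row_sparse'
try:
    p.data(mx.cpu(0))                 # not allowed for sparse storage
except RuntimeError:
    pass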
diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py
index f285b9187e8..ef20109021a 100644
--- a/python/mxnet/gluon/trainer.py
+++ b/python/mxnet/gluon/trainer.py
@@ -21,7 +21,7 @@
__all__ = ['Trainer']
from .. import optimizer as opt
-from ..model import _create_kvstore
+from ..model import _create_kvstore, _create_sparse_kvstore
from .parameter import ParameterDict, Parameter
class Trainer(object):
@@ -68,20 +68,30 @@ def __init__(self, params, optimizer, optimizer_params=None, kvstore='device',
"First argument must be a list or dict of Parameters, " \
"got %s."%(type(params)))
self._params = []
- for param in params:
+ # parameters to initialize on the kvstore
+ self._contains_sparse = False
+ self._param2idx = {}
+ for i, param in enumerate(params):
if not isinstance(param, Parameter):
raise ValueError(
"First argument must be a list or dict of Parameters, " \
"got list of %s."%(type(param)))
+ self._param2idx[param.name] = i
self._params.append(param)
+ param._set_trainer(self)
+ if param._stype != 'default':
+ self._contains_sparse = True
self._compression_params = compression_params
optimizer_params = optimizer_params if optimizer_params else {}
self._scale = float(optimizer_params.get('rescale_grad', 1.0))
self._contexts = self._check_contexts()
self._init_optimizer(optimizer, optimizer_params)
+ self._kvstore_params = {'kvstore': kvstore, 'update_on_kvstore': update_on_kvstore}
self._kv_initialized = False
- self._kvstore = kvstore
- self._update_on_kvstore = update_on_kvstore
+ self._kvstore = None
+ self._update_on_kvstore = None
+ self._params_to_init = []
+ self._reset_kvstore()
def _check_contexts(self):
contexts = None
@@ -109,38 +119,62 @@ def _init_optimizer(self, optimizer, optimizer_params):
self._updaters = [opt.get_updater(self._optimizer) \
for _ in self._contexts]
+ def _init_params(self):
+ """Initialize parameters in the KVStore.
+
+ Parameters with incomplete initialization are ignored.
+
+ """
+ assert self._kv_initialized, "Cannot initialize parameters in KVStore " \
+ "when KVStore is not initialized."
+ params_to_init = []
+ if self._kvstore:
+ for param in self._params_to_init:
+ if param._deferred_init:
+ params_to_init.append(param)
+ else:
+ param_arrays = param._check_and_get(param._data, list)
+ idx = self._param2idx[param.name]
+ self._kvstore.init(idx, param_arrays[0])
+ if param._stype == 'default':
+ self._kvstore.pull(idx, param_arrays, priority=-idx)
+
+ self._params_to_init = params_to_init
+
+ def _reset_kvstore(self):
+ """Reset kvstore."""
+ if self._kvstore and 'dist' in self._kvstore.type:
+ raise RuntimeError("Cannot reset distributed KVStore.")
+ self._kv_initialized = False
+ self._kvstore = None
+ self._update_on_kvstore = None
+ self._params_to_init = [param for param in self._params]
+
def _init_kvstore(self):
+ """Create kvstore."""
arg_arrays = {}
- contains_sparse = False
- for param in self._params:
- arg_arrays[param.name] = param.data(self._contexts[0])
- if param._grad_stype != 'default':
- contains_sparse = True
- # update_on_kvstore is set to False by the user
- if self._update_on_kvstore is False:
- raise RuntimeError("Cannot set update_on_kvstore to False
when sparse "
- "gradients and/or sparse weights are
present for "
- "Parameter %s." % param.name)
- kvstore, update_on_kvstore = _create_kvstore(self._kvstore,
len(self._contexts),
- arg_arrays)
- update_on_kvstore = self._update_on_kvstore if self._update_on_kvstore
is not None \
- else update_on_kvstore
+ config = self._kvstore_params
+ if self._contains_sparse:
+ kvstore, update_on_kvstore =
_create_sparse_kvstore(config['kvstore'])
+ # update_on_kvstore is set to False by the user
+ if config['update_on_kvstore'] is False:
+ raise RuntimeError("Cannot set update_on_kvstore to False when
sparse "
+ "gradients and/or sparse weights are
present for "
+ "Parameter '%s'."%param.name)
+ else:
+ kvstore, update_on_kvstore = _create_kvstore(config['kvstore'],
len(self._contexts),
+ arg_arrays)
+ if config['update_on_kvstore'] is not None:
+ update_on_kvstore = config['update_on_kvstore']
if kvstore:
if self._compression_params:
kvstore.set_gradient_compression(self._compression_params)
# kv.pull(row_sparse_grad) is not supported
- if contains_sparse:
- update_on_kvstore = True
- else:
- if 'dist' in kvstore.type:
- update_on_kvstore = False
+ if 'dist' in kvstore.type and not self._contains_sparse:
+ update_on_kvstore = False
if update_on_kvstore:
+ # optimizer preferably needs to be set before init for multiprecision
kvstore.set_optimizer(self._optimizer)
- # optimizer preferably needs to be set before init for multiprecision
- for i, param in enumerate(self._params):
- param_arrays = param.list_data()
- kvstore.init(i, param_arrays[0])
- kvstore.pull(i, param_arrays, priority=-i)
self._kvstore = kvstore
self._update_on_kvstore = update_on_kvstore
else:
@@ -171,6 +205,15 @@ def set_learning_rate(self, lr):
else:
self._optimizer.set_learning_rate(lr)
+ def _row_sparse_pull(self, parameter, out, row_id):
+ # initialize kv and params if not already
+ if not self._kv_initialized:
+ self._init_kvstore()
+ if self._params_to_init:
+ self._init_params()
+ self._kvstore.row_sparse_pull(self._param2idx[parameter.name], \
+ out=out, row_ids=row_id)
+
def step(self, batch_size, ignore_stale_grad=False):
"""Makes one step of parameter update. Should be called after
`autograd.backward()` and outside of `record()` scope.
@@ -191,6 +234,8 @@ def step(self, batch_size, ignore_stale_grad=False):
"""
if not self._kv_initialized:
self._init_kvstore()
+ if self._params_to_init:
+ self._init_params()
self._optimizer.rescale_grad = self._scale / batch_size
@@ -210,6 +255,8 @@ def allreduce_grads(self):
"""
if not self._kv_initialized:
self._init_kvstore()
+ if self._params_to_init:
+ self._init_params()
assert not (self._kvstore and self._update_on_kvstore), \
'allreduce_grads() when parameters are updated on kvstore ' \
'is not supported. Try setting `update_on_kvstore` ' \
@@ -250,6 +297,8 @@ def update(self, batch_size, ignore_stale_grad=False):
"""
if not self._kv_initialized:
self._init_kvstore()
+ if self._params_to_init:
+ self._init_params()
assert not (self._kvstore and self._update_on_kvstore), \
'update() when parameters are updated on kvstore ' \
'is not supported. Try setting `update_on_kvstore` ' \
@@ -264,7 +313,7 @@ def _update(self, ignore_stale_grad=False):
continue
if not ignore_stale_grad:
- for data in param.list_data():
+ for data in param._check_and_get(param._data, list):
if not data._fresh_grad:
raise UserWarning(
"Gradient of Parameter `%s` on context %s has not
been updated "
@@ -276,7 +325,10 @@ def _update(self, ignore_stale_grad=False):
%(param.name, str(data.context)))
if self._kvstore and self._update_on_kvstore:
- self._kvstore.pull(i, param.list_data(), priority=-i)
+ if param._stype == 'default':
+ # 'row_sparse' parameters are not pulled immediately - they're pulled
+ # in `SparseBlock.sparse_forward`
+ self._kvstore.pull(i, param.list_data(), priority=-i)
continue
for upd, arr, grad in zip(self._updaters, param.list_data(), param.list_grad()):
@@ -296,8 +348,12 @@ def save_states(self, fname):
if not self._kv_initialized:
self._init_kvstore()
+ if self._params_to_init:
+ self._init_params()
if self._update_on_kvstore:
+ assert not self._params_to_init, "Cannot save trainer states when some " \
+ "parameters are not yet initialized in kvstore."
self._kvstore.save_optimizer_states(fname, dump_optimizer=True)
else:
with open(fname, 'wb') as fout:
@@ -313,6 +369,8 @@ def load_states(self, fname):
"""
if not self._kv_initialized:
self._init_kvstore()
+ if self._params_to_init:
+ self._init_params()
if self._update_on_kvstore:
self._kvstore.load_optimizer_states(fname)
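
With these changes the Trainer builds its kvstore lazily: step(), allreduce_grads(), update(), save_states(), load_states() and _row_sparse_pull() all run _init_kvstore()/_init_params() on first use. A minimal local sketch of that flow, patterned on test_gluon_trainer_reset below:

import mxnet as mx
from mxnet import gluon

x = gluon.Parameter('x', shape=(4, 2), stype='row_sparse')
x.initialize(ctx=mx.cpu(0), init='zeros')
trainer = gluon.Trainer([x], 'sgd')
# no kvstore exists yet; the first row_sparse_data() call triggers
# _init_kvstore() and _init_params(), then a row_sparse_pull
w = x.row_sparse_data(mx.nd.arange(0, 4))
assert trainer._kv_initialized and trainer._update_on_kvstore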
diff --git a/python/mxnet/model.py b/python/mxnet/model.py
index ae7726d76a7..3a50553a615 100644
--- a/python/mxnet/model.py
+++ b/python/mxnet/model.py
@@ -55,6 +55,25 @@
'eval_metric',
'locals'])
+def _create_sparse_kvstore(kvstore):
+ """Create kvstore assuming some parameters' storage types are row_sparse.
+
+ Parameters
+ ----------
+ kvstore : KVStore or str
+ The kvstore.
+ """
+ # always update on kvstore
+ update_on_kvstore = True
+ if isinstance(kvstore, kvs.KVStore):
+ kv = kvstore
+ elif isinstance(kvstore, str):
+ kv = kvs.create(kvstore)
+ else:
+ raise TypeError("Cannot create '%s' KVStore with row_sparse
parameters. "
+ "The type must be KVStore or str." % kvstore)
+ return (kv, update_on_kvstore)
+
def _create_kvstore(kvstore, num_device, arg_params):
"""Create kvstore
This function selects and creates a proper kvstore if given the kvstore type.
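
A sketch of the new helper's contract (it is a private helper, shown for illustration only): any KVStore instance or type string is accepted, and update_on_kvstore is always True, because sparse weights must be fetched with row_sparse_pull rather than kv.pull:

from mxnet import kvstore as kvs
from mxnet.model import _create_sparse_kvstore

kv, update_on_kvstore = _create_sparse_kvstore('local')
assert isinstance(kv, kvs.KVStore) and update_on_kvstore is True
try:
    _create_sparse_kvstore(None)  # neither KVStore nor str
except TypeError:
    pass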
diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h
index 28827db0e63..23a866d75af 100644
--- a/src/operator/tensor/indexing_op.h
+++ b/src/operator/tensor/indexing_op.h
@@ -270,6 +270,12 @@ inline bool EmbeddingOpBackwardStorageType(const nnvm::NodeAttrs& attrs,
dispatched = dispatch_mode_assign(dispatch_mode, target_mode);
}
}
+ // Print user friendly error message to notify misuses of sparse_grad
+ if (weight_grad_stype != target_stype) {
+ LOG(FATAL) << "Cannot use sparse_grad = " << sparse_grad
+ << ", while stype of gradients w.r.t embedding weight is "
+ << common::stype_string(weight_grad_stype);
+ }
return dispatched;
}
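
On the Python side, the configuration this check blesses looks as follows; a hedged sketch (not part of the PR) of Embedding with sparse_grad=True, where the gradient w.r.t. the weight is expected to be 'row_sparse':

import mxnet as mx

weight = mx.nd.random.uniform(shape=(4, 3))
weight.attach_grad(stype='row_sparse')   # grad stype must be consistent with sparse_grad=True
data = mx.nd.array([0, 2])
with mx.autograd.record():
    emb = mx.nd.Embedding(data=data, weight=weight, input_dim=4,
                          output_dim=3, sparse_grad=True)
emb.backward()
assert weight.grad.stype == 'row_sparse'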
diff --git a/tests/nightly/dist_sync_kvstore.py b/tests/nightly/dist_sync_kvstore.py
index 3bf5cbffa13..32ed2dddb6f 100644
--- a/tests/nightly/dist_sync_kvstore.py
+++ b/tests/nightly/dist_sync_kvstore.py
@@ -24,7 +24,7 @@
import mxnet as mx
import numpy as np
import numpy.random as rnd
-from mxnet.test_utils import assert_almost_equal
+from mxnet.test_utils import assert_almost_equal, assert_exception
from test_kvstore import compute_expected_2bit_quantization
def check_diff(A, x, rank=None):
@@ -350,6 +350,20 @@ def check_init(kv, cur_keys, cur_shape, device=False):
check_init(kv, init_test_keys_device_big, big_shape, device=True)
print('worker ' + str(kv.rank) + ' is initialized')
+def test_gluon_trainer_reset():
+ params = mx.gluon.ParameterDict()
+ x = params.get('x', shape=(4, 2), lr_mult=1.0, stype='row_sparse')
+ params.initialize(ctx=mx.cpu(0), init='zeros')
+ trainer = mx.gluon.Trainer(params, 'sgd', {'learning_rate': 0.1}, kvstore=kv)
+ params.save('test_gluon_trainer_reset_' + str(my_rank) + '.params')
+ row_id = mx.nd.arange(0, 4)
+ w = x.row_sparse_data(row_id)
+ assert trainer._kv_initialized and trainer._update_on_kvstore
+ # load would fail to reset kvstore since update_on_kvstore is True
+ assert_exception(params.load, RuntimeError, 'test_gluon_trainer_reset_' + str(my_rank) + '.params')
+ print('worker ' + str(my_rank) + ' passed test_gluon_trainer_reset')
+
+
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='test distributed kvstore in dist_sync mode')
parser.add_argument('--nrepeat', type=int, default=7)
@@ -357,13 +371,16 @@ def check_init(kv, cur_keys, cur_shape, device=False):
parser.add_argument('--no-gpu', dest='gpu', action='store_false')
parser.add_argument('--no-multiprecision', dest='multiprecision', action='store_false')
opt = parser.parse_args()
- if opt.type == 'all' or opt.type == 'init':
+ if opt.type == 'gluon':
+ test_gluon_trainer_reset()
+ if opt.type == 'all' or opt.type == 'init':
test_sync_init(opt.gpu)
- kv = init_kv()
- if opt.type == 'all' or opt.type == 'default':
+ if opt.type == 'all' or opt.type == 'default':
+ kv = init_kv()
kv = set_optimizer(use_multiprecision=opt.multiprecision)
test_sync_push_pull(opt.nrepeat)
# don't run non compressed tests after this as kvstore compression will be set here
- if opt.type == 'all' or opt.type == 'compressed':
+ if opt.type == 'all' or opt.type == 'compressed':
+ kv = init_kv()
kv, threshold = init_kv_compressed(kv)
test_sync_2bit_compression(threshold, opt.nrepeat)
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index b1b5fe2fe50..2384812ef64 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -19,7 +19,8 @@
from mxnet import gluon
from mxnet.gluon import nn
from mxnet.test_utils import assert_almost_equal
-from common import setup_module, with_seed
+from mxnet.ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID
+from common import setup_module, with_seed, assertRaises
import numpy as np
from nose.tools import raises, assert_raises
from copy import deepcopy
@@ -27,8 +28,6 @@
import json
import unittest
-
-
@with_seed()
def test_parameter():
p = gluon.Parameter('weight', shape=(10, 10))
@@ -39,33 +38,122 @@ def test_parameter():
assert p.data(mx.cpu(0)).shape == (10, 10)
assert p.var().name == 'weight'
assert p.grad(mx.cpu(0)).stype == 'default'
+ assert p.data(mx.cpu(0)).stype == 'default'
p.reset_ctx(ctx=[mx.cpu(1), mx.cpu(2)])
assert p.list_ctx() == [mx.cpu(1), mx.cpu(2)]
+@with_seed()
+@raises(AssertionError)
+def test_invalid_parameter_stype():
+ p = gluon.Parameter('weight', shape=(10, 10), stype='invalid')
+
+@with_seed()
+@raises(AssertionError)
+def test_invalid_parameter_grad_stype():
+ p = gluon.Parameter('weight', shape=(10, 10), grad_stype='invalid')
+
@with_seed()
def test_sparse_parameter():
- p = gluon.Parameter('weight', shape=(10, 10), grad_stype='row_sparse')
+ p = gluon.Parameter('weight', shape=(10, 10), stype='row_sparse', grad_stype='row_sparse')
p.initialize(init='xavier', ctx=[mx.cpu(0), mx.cpu(1)])
- assert len(p.list_data()) == 2
+ row_id = mx.nd.arange(0, 10, ctx=mx.cpu(1))
assert len(p.list_grad()) == 2
- assert p.data(mx.cpu(1)).context == mx.cpu(1)
- assert p.data(mx.cpu(0)).shape == (10, 10)
+ # getting row_sparse data without trainer throws an exception
+ assertRaises(RuntimeError, p.list_row_sparse_data, row_id)
+ trainer = mx.gluon.Trainer([p], 'sgd')
+ assert len(p.list_row_sparse_data(row_id)) == 2
+ weight = p.row_sparse_data(row_id)
+ assert weight.context == mx.cpu(1)
+ assert weight.shape == (10, 10)
+ assert weight.stype == 'row_sparse'
assert p.var().name == 'weight'
+ assert p.var().attr('__storage_type__') == str(_STORAGE_TYPE_STR_TO_ID['row_sparse'])
assert p.grad(mx.cpu(0)).stype == 'row_sparse'
p.reset_ctx(ctx=[mx.cpu(1), mx.cpu(2)])
assert p.list_ctx() == [mx.cpu(1), mx.cpu(2)]
+@with_seed()
+def test_parameter_invalid_access():
+ # cannot call data on row_sparse parameters
+ p0 = gluon.Parameter('weight', shape=(10, 10), stype='row_sparse', grad_stype='row_sparse')
+ p0.initialize(init='xavier', ctx=[mx.cpu(0), mx.cpu(1)])
+ assertRaises(RuntimeError, p0.data)
+ assertRaises(RuntimeError, p0.list_data)
+ row_id = mx.nd.arange(0, 10)
+ # cannot call row_sparse_data on dense parameters
+ p1 = gluon.Parameter('weight', shape=(10, 10))
+ p1.initialize(init='xavier', ctx=[mx.cpu(0), mx.cpu(1)])
+ assertRaises(RuntimeError, p1.row_sparse_data, row_id.copyto(mx.cpu(0)))
+ assertRaises(RuntimeError, p1.list_row_sparse_data, row_id)
@with_seed()
def test_paramdict():
- params = gluon.ParameterDict('net_')
- params.get('weight', shape=(10, 10))
- assert list(params.keys()) == ['net_weight']
- params.initialize(ctx=mx.cpu())
- params.save('test.params')
- params.load('test.params', mx.cpu())
+ params0 = gluon.ParameterDict('net_')
+ params0.get('w0', shape=(10, 10))
+ params0.get('w1', shape=(10, 10), stype='row_sparse')
+ all_row_ids = mx.nd.arange(0, 10, ctx=mx.cpu())
+ # check param names
+ assert list(params0.keys()) == ['net_w0', 'net_w1']
+ params0.initialize(ctx=mx.cpu())
+ trainer0 = mx.gluon.Trainer(params0, 'sgd')
+ prev_w0 = params0.get('w0').data(mx.cpu())
+ prev_w1 = params0.get('w1').row_sparse_data(all_row_ids)
+ # save params
+ params0.save('test_paramdict.params')
+
+ # load params
+ params1 = gluon.ParameterDict('net_')
+ params1.get('w0', shape=(10, 10))
+ params1.get('w1', shape=(10, 10), stype='row_sparse')
+ params1.load('test_paramdict.params', mx.cpu())
+ trainer1 = mx.gluon.Trainer(params1, 'sgd')
+
+ # compare the values before and after save/load
+ cur_w0 = params1.get('w0').data(mx.cpu())
+ cur_w1 = params1.get('w1').row_sparse_data(all_row_ids)
+ mx.test_utils.assert_almost_equal(prev_w0.asnumpy(), cur_w0.asnumpy())
+ mx.test_utils.assert_almost_equal(prev_w1.asnumpy(), cur_w1.asnumpy())
+
+ # create a new param dict with dense params, and load from the checkpoint
+ # of sparse & dense params
+ params2 = gluon.ParameterDict('net_')
+ params2.get('w0', shape=(10, 10))
+ params2.get('w1', shape=(10, 10))
+ params2.load('test_paramdict.params', mx.cpu())
+
+ # compare the values before and after save/load
+ cur_w0 = params2.get('w0').data(mx.cpu())
+ cur_w1 = params2.get('w1').data(mx.cpu())
+ mx.test_utils.assert_almost_equal(prev_w0.asnumpy(), cur_w0.asnumpy())
+ mx.test_utils.assert_almost_equal(prev_w1.asnumpy(), cur_w1.asnumpy())
+
+
+@with_seed()
+def test_parameter_row_sparse_data():
+ ctx0 = mx.cpu(1)
+ ctx1 = mx.cpu(2)
+ dim0 = 4
+ x = gluon.Parameter('x', shape=(dim0, 2), stype='row_sparse')
+ x.initialize(init='xavier', ctx=[ctx0, ctx1])
+ trainer = gluon.Trainer([x], 'sgd')
+ x_param = x._data[0].copy()
+ assert x_param.stype == 'row_sparse'
+ row_id_0 = mx.nd.array([0,1], ctx=ctx0)
+ retained_0 = x.row_sparse_data(row_id_0)
+ retained_target_0 = mx.nd.sparse.retain(x_param, row_id_0.as_in_context(ctx0))
+ mx.test_utils.assert_almost_equal(retained_0.asnumpy(), retained_target_0.asnumpy())
+ assert retained_0.context == ctx0
+ row_id_1 = mx.nd.arange(0, dim0, ctx=ctx1)
+ retained_1 = x.row_sparse_data(row_id_1)
+ retained_target_1 = x_param
+ mx.test_utils.assert_almost_equal(retained_1.asnumpy(), retained_target_1.asnumpy())
+ assert retained_1.context == ctx1
+ row_id_2 = mx.nd.array([0,1,2])
+ retained_2 = x.list_row_sparse_data(row_id_2)
+ retained_target_2 = mx.nd.sparse.retain(x_param, row_id_2.as_in_context(ctx0))
+ mx.test_utils.assert_almost_equal(retained_2[0].asnumpy(), retained_target_2.asnumpy())
@with_seed()
@@ -246,7 +334,29 @@ def hybrid_forward(self, F, x):
net.hybridize()
assert isinstance(net(mx.nd.zeros((16, 10))), mx.nd.NDArray)
+@with_seed()
+@raises(AssertionError)
+def test_sparse_symbol_block():
+ data = mx.sym.var('data')
+ weight = mx.sym.var('weight', stype='row_sparse')
+ bias = mx.sym.var('bias')
+ out = mx.sym.broadcast_add(mx.sym.dot(data, weight), bias)
+ # an exception is expected when creating a SparseBlock w/ sparse param
+ net = gluon.SymbolBlock(out, data)
+@with_seed()
+@raises(RuntimeError)
+def test_sparse_hybrid_block():
+ params = gluon.ParameterDict('net_')
+ params.get('weight', shape=(5,5), stype='row_sparse', dtype='float32')
+ params.get('bias', shape=(5,), dtype='float32')
+ net = gluon.nn.Dense(5, params=params)
+ net.initialize()
+ x = mx.nd.ones((2,5))
+ # an exception is expected when forwarding a HybridBlock w/ sparse param
+ y = net(x)
+
+@with_seed()
def check_layer_forward(layer, dshape):
layer.collect_params().initialize()
x = mx.nd.ones(shape=dshape)
@@ -496,80 +606,6 @@ def test_flatten():
x = mx.nd.zeros((3,))
assert flatten(x).shape == (3, 1)
-
-@with_seed()
-def test_trainer():
- def dict_equ(a, b):
- assert set(a) == set(b)
- for k in a:
- assert (a[k].asnumpy() == b[k].asnumpy()).all()
- x = gluon.Parameter('x', shape=(10,))
- x.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
- trainer = gluon.Trainer([x], 'sgd', {'learning_rate': 1.0, 'momentum': 0.5})
- with mx.autograd.record():
- for w in x.list_data():
- y = w + 1
- y.backward()
- trainer.step(1)
-
- assert (x.data(mx.cpu(1)).asnumpy() == -2).all()
-
- x.lr_mult = 0.5
-
- with mx.autograd.record():
- for w in x.list_data():
- y = w + 1
- y.backward()
- trainer.step(1)
-
- assert (x.data(mx.cpu(1)).asnumpy() == -4).all()
-
- trainer.save_states('test_trainer.states')
- states = deepcopy(trainer._kvstore._updater.states) if trainer._update_on_kvstore \
- else deepcopy(trainer._updaters[0].states)
- trainer.load_states('test_trainer.states')
- if trainer._update_on_kvstore:
- dict_equ(trainer._kvstore._updater.states, states)
- assert trainer._optimizer == trainer._kvstore._updater.optimizer
- else:
- for updater in trainer._updaters:
- dict_equ(updater.states, states)
- assert trainer._optimizer == trainer._updaters[0].optimizer
- assert_raises(AssertionError, trainer.update, 1)
- assert_raises(AssertionError, trainer.allreduce_grads)
-
- x = gluon.Parameter('x', shape=(10,))
- x.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
- trainer2 = gluon.Trainer([x], 'sgd', {'learning_rate': 1.0, 'momentum': 0.5},
- update_on_kvstore=False)
- with mx.autograd.record():
- for i, w in enumerate(x.list_data()):
- y = i*w
- y.backward()
- assert (x.grad(mx.cpu(0)).asnumpy() != x.grad(mx.cpu(1)).asnumpy()).all()
- trainer2.allreduce_grads()
- assert (x.grad(mx.cpu(0)).asnumpy() == x.grad(mx.cpu(1)).asnumpy()).all()
- trainer2.update(1)
-
- assert (x.data(mx.cpu(1)).asnumpy() == -1).all(), x.data(mx.cpu(1)).asnumpy()
-
-@with_seed()
-def test_trainer_save_load():
- x = gluon.Parameter('x', shape=(10,), lr_mult=1.0)
- x.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
- trainer = gluon.Trainer([x], 'sgd', {'learning_rate': 0.1})
- with mx.autograd.record():
- for w in x.list_data():
- y = w + 1
- y.backward()
- trainer.step(1)
- assert trainer._kvstore._updater.optimizer._get_lr(0) == 0.1
- trainer.save_states('test_trainer_save_load.states')
- trainer.load_states('test_trainer_save_load.states')
- x.lr_mult = 2.0
- # check if parameter dict is correctly associated with optimizer after load_state
- assert trainer._kvstore._updater.optimizer._get_lr(0) == 0.2
-
@with_seed()
def test_block_attr_hidden():
b = gluon.Block()
@@ -900,6 +936,7 @@ def test_inline():
assert len_1 == len_2 + 2
+@with_seed()
def test_activations():
point_to_validate = mx.nd.array([-0.1, 0.1] * 3)
@@ -1013,13 +1050,14 @@ def test_req():
@with_seed()
def test_save_load():
net = mx.gluon.model_zoo.vision.get_resnet(1, 18, pretrained=True)
- net.save_params('test.params')
+ net.save_params('test_save_load.params')
net = mx.gluon.model_zoo.vision.get_resnet(1, 18)
net.output = mx.gluon.nn.Dense(1000)
- net.load_params('test.params')
+ net.load_params('test_save_load.params')
+@with_seed()
def test_symbol_block_save_load():
class Net(gluon.HybridBlock):
def __init__(self):
@@ -1042,10 +1080,10 @@ def hybrid_forward(self, F, x):
net1.initialize(mx.init.Normal())
net1.hybridize()
net1(mx.nd.random.normal(shape=(1, 3, 32, 32)))
- net1.save_params('./test.params')
+ net1.save_params('./test_symbol_block_save_load.params')
net2 = Net()
- net2.load_params('./test.params', ctx=mx.cpu())
+ net2.load_params('./test_symbol_block_save_load.params', ctx=mx.cpu())
@with_seed()
diff --git a/tests/python/unittest/test_gluon_trainer.py b/tests/python/unittest/test_gluon_trainer.py
new file mode 100644
index 00000000000..c2e11ebb18e
--- /dev/null
+++ b/tests/python/unittest/test_gluon_trainer.py
@@ -0,0 +1,200 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+import unittest
+import numpy as np
+from mxnet import gluon
+from mxnet.gluon import nn
+from mxnet.test_utils import assert_almost_equal
+from common import setup_module, with_seed, assertRaises
+from copy import deepcopy
+from nose.tools import raises, assert_raises
+
+@with_seed()
+@raises(RuntimeError)
+def test_multi_trainer():
+ x = gluon.Parameter('x', shape=(10,), stype='row_sparse')
+ x.initialize()
+ # test set trainer
+ trainer0 = gluon.Trainer([x], 'sgd')
+ assert(x._trainer is trainer0)
+ # test unset trainer
+ x._set_trainer(None)
+ assert(x._trainer is None)
+ x._set_trainer(trainer0)
+ # multiple trainers for a sparse Parameter is not allowed
+ trainer1 = gluon.Trainer([x], 'sgd')
+
+@with_seed()
+def test_trainer():
+ def dict_equ(a, b):
+ assert set(a) == set(b)
+ for k in a:
+ assert (a[k].asnumpy() == b[k].asnumpy()).all()
+ x = gluon.Parameter('x', shape=(10,))
+ x.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
+ trainer = gluon.Trainer([x], 'sgd', {'learning_rate': 1.0, 'momentum': 0.5})
+ with mx.autograd.record():
+ for w in x.list_data():
+ y = w + 1
+ y.backward()
+ trainer.step(1)
+
+ assert (x.data(mx.cpu(1)).asnumpy() == -2).all()
+
+ x.lr_mult = 0.5
+
+ with mx.autograd.record():
+ for w in x.list_data():
+ y = w + 1
+ y.backward()
+ trainer.step(1)
+
+ assert (x.data(mx.cpu(1)).asnumpy() == -4).all()
+
+ trainer.save_states('test_trainer.states')
+ states = deepcopy(trainer._kvstore._updater.states) if trainer._update_on_kvstore \
+ else deepcopy(trainer._updaters[0].states)
+ trainer.load_states('test_trainer.states')
+ if trainer._update_on_kvstore:
+ dict_equ(trainer._kvstore._updater.states, states)
+ assert trainer._optimizer == trainer._kvstore._updater.optimizer
+ else:
+ for updater in trainer._updaters:
+ dict_equ(updater.states, states)
+ assert trainer._optimizer == trainer._updaters[0].optimizer
+ assert_raises(AssertionError, trainer.update, 1)
+ assert_raises(AssertionError, trainer.allreduce_grads)
+
+ x = gluon.Parameter('x', shape=(10,))
+ x.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
+ trainer2 = gluon.Trainer([x], 'sgd', {'learning_rate': 1.0, 'momentum': 0.5},
+ update_on_kvstore=False)
+ with mx.autograd.record():
+ for i, w in enumerate(x.list_data()):
+ y = i*w
+ y.backward()
+ assert (x.grad(mx.cpu(0)).asnumpy() != x.grad(mx.cpu(1)).asnumpy()).all()
+ trainer2.allreduce_grads()
+ assert (x.grad(mx.cpu(0)).asnumpy() == x.grad(mx.cpu(1)).asnumpy()).all()
+ trainer2.update(1)
+
+ assert (x.data(mx.cpu(1)).asnumpy() == -1).all(), x.data(mx.cpu(1)).asnumpy()
+
+@with_seed()
+def test_trainer_save_load():
+ x = gluon.Parameter('x', shape=(10,), lr_mult=1.0)
+ x.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
+ trainer = gluon.Trainer([x], 'sgd', {'learning_rate': 0.1})
+ with mx.autograd.record():
+ for w in x.list_data():
+ y = w + 1
+ y.backward()
+ trainer.step(1)
+ assert trainer._kvstore._updater.optimizer._get_lr(0) == 0.1
+ trainer.save_states('test_trainer_save_load.states')
+ trainer.load_states('test_trainer_save_load.states')
+ x.lr_mult = 2.0
+ # check if parameter dict is correctly associated with optimizer after load_state
+ assert trainer._kvstore._updater.optimizer._get_lr(0) == 0.2
+
+@with_seed()
+def test_trainer_multi_layer_init():
+ class Net(gluon.Block):
+ def __init__(self, **kwargs):
+ super(Net, self).__init__(**kwargs)
+ with self.name_scope():
+ # sparse param
+ self.embed_weight = self.params.get('embed_weight', stype='row_sparse',
+ shape=(4,3), grad_stype='row_sparse')
+ # dense param from a hybrid block
+ self.dense0 = nn.Dense(2)
+
+ def forward(self, x):
+ embed_weight = self.embed_weight.row_sparse_data(x)
+ embed = mx.nd.Embedding(data=x, weight=embed_weight,
+ input_dim=4, output_dim=3, sparse_grad=True)
+ return self.dense0(embed)
+
+ def check_init(ctxes):
+ net = Net(prefix='net_')
+ net.initialize(mx.init.One(), ctx=ctxes)
+ trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 1})
+ data = mx.nd.array([[0,2], [1,2]])
+ xs = gluon.utils.split_and_load(data, ctxes)
+ ys = []
+ with mx.autograd.record():
+ for x in xs:
+ y = net(x)
+ ys.append(y)
+ for y in ys:
+ y.backward()
+ trainer.step(1)
+ # all parameters should be initialized
+ assert not trainer._params_to_init
+ all_rows = mx.nd.arange(0, 4, ctx=mx.cpu(1))
+ # check the updated weights
+ weight = net.embed_weight.row_sparse_data(all_rows).asnumpy()
+ assert (weight[0] == -1).all()
+ assert (weight[1] == -1).all()
+ assert (weight[2] == -3).all()
+ assert (weight[3] == 1).all()
+
+ check_init([mx.cpu(1), mx.cpu(2)])
+ check_init([mx.cpu(1)])
+
+@with_seed()
+def test_trainer_reset_kv():
+ params = gluon.ParameterDict()
+ x = params.get('x', shape=(10,), lr_mult=1.0)
+ params.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
+ trainer = gluon.Trainer(params, 'sgd', {'learning_rate': 0.1})
+ params.save('test_trainer_reset_kv.params')
+ with mx.autograd.record():
+ for w in x.list_data():
+ y = w + 1
+ y.backward()
+ trainer.step(1)
+ # load would reset kvstore
+ params.load('test_trainer_reset_kv.params')
+ assert trainer._kvstore is None
+ assert trainer._kv_initialized is False
+ with mx.autograd.record():
+ for w in x.list_data():
+ y = w + 1
+ y.backward()
+ trainer.step(1)
+ # the updated parameter should be based on the loaded checkpoint
+ assert (x.data(mx.cpu()) == -0.2).asnumpy().all()
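
For reference, the headline constraint the diff introduces: a 'row_sparse' Parameter is tied to a single Trainer, since the authoritative copy of the weight lives in that Trainer's kvstore. A minimal sketch mirroring test_multi_trainer above:

import mxnet as mx
from mxnet import gluon

x = gluon.Parameter('x', shape=(10,), stype='row_sparse')
x.initialize()
trainer0 = gluon.Trainer([x], 'sgd')
try:
    gluon.Trainer([x], 'sgd')  # second Trainer on a sparse Parameter
except RuntimeError:
    pass                       # _set_trainer refuses to replace trainer0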
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services