This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 82a3d21  Added sparsity functionality, with tests (#7138)
82a3d21 is described below

commit 82a3d21104c348610d6f5e224c89a4382302725f
Author: Guneet Singh Dhillon <[email protected]>
AuthorDate: Thu Aug 3 13:08:14 2017 -0700

    Added sparsity functionality, with tests (#7138)
    
    * added pruning for sgd
    
    * added pruning for example/image-classification
    
    * working example for imagenet to experiment on
    
    * added flexibility to start off with pruning
    
    * changes to imagenet code
    
    * minor changes for testing
    
    * changes to imagenet pruning
    
    * small changes to parameters for tests
    
    * DSD test on mnist added
    
    * improved sparsification, added sparse-sparse training, added pruning 
factor
    
    * changed test for more coverage
    
    * updated example
    
    * updated example to save models
    
    * added thresholding by user
    
    * made optimizer code cleaner, created tests - mlp and rnn
    
    * added thresholding functionality, and related tests
    
    * made minor change to tests
    
    * updated common file, changed to merger
    
    * merging
    
    * reverted for mshadow
    
    * reverted dmlc-core
    
    * back to old examples
    
    * removed spaces from code
    
    * added comments
    
    * another style change
    
    * made SparseSGD a subclass
    
    * removed dependencies from tests
    
    * minor changes
    
    * reduced checks - not needed
    
    * call sgd from sparsesgd
    
    * corrected syntax
    
    * corrected syntax
    
    * reverted back, handle epoch count myself
    
    * added DSD training to examples
    
    * added mask generation logic
    
    * added comment on layer-wise vs global pruning
    
    * added update message in sparse_sgd
    
    * added an example
    
    * changes to README
---
 example/dsd/README.md     |  30 ++++++++
 example/dsd/mlp.py        | 125 ++++++++++++++++++++++++++++++++++
 example/dsd/sparse_sgd.py | 170 ++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 325 insertions(+)

diff --git a/example/dsd/README.md b/example/dsd/README.md
new file mode 100644
index 0000000..0ce5cc5
--- /dev/null
+++ b/example/dsd/README.md
@@ -0,0 +1,30 @@
+DSD Training
+============
+This folder contains an optimizer class that implements DSD training coupled 
with SGD. The training
+procedure is described in the paper *DSD: Dense-Sparse-Dense Training for Deep 
Neural Networks*,
+available at https://arxiv.org/pdf/1607.04381.pdf
+
+The optimizer class is flexible in the way it prunes weights. The user can 
define the following:
+-   The percentage sparsity they want or the thresholding value for the pruning
+-   The epochs at which they want a particular level of pruning
+
+Note that giving the sparsity level induces that level of sparsity in every 
layer of the neural
network. It is layer-wise pruning, not global pruning (which would require
looking at all the
+weights of the neural network at the same time). However, global pruning can 
be done if the
+threshold value is known to the user (by doing some preprocessing), and is 
passed to the optimizer.
+
+## Example
+
+To test out the sparsity feature on a MLP, run the following script:
+
+    python mlp.py --pruning_switch_epoch 4,7,10 --bias_sparsity 0,30,50 
--weight_sparsity 0,50,70
+
+This will train a MLP with 0% sparsity until epoch 4, with 30% bias and 50%
weight sparsity until
epoch 7, and 50% bias and 70% weight sparsity until epoch 10.
+
+To test out the thresholding feature on a MLP, run the following script:
+
+    python mlp.py --pruning_switch_epoch 4,6 --bias_threshold 0,0.01 
--weight_threshold 0,0.05
+
+This will train a MLP with thresholding at 0 until epoch 4, with bias
thresholding at 0.01 and
weight thresholding at 0.05 until epoch 6.
diff --git a/example/dsd/mlp.py b/example/dsd/mlp.py
new file mode 100644
index 0000000..ccb0940
--- /dev/null
+++ b/example/dsd/mlp.py
@@ -0,0 +1,125 @@
+import mxnet as mx
+import os
+import logging
+import argparse
+from math import ceil
+import sparse_sgd
+
# symbol net
def get_symbol():
    """Build the MLP symbol: 784 -> 128 -> 64 -> 10 with ReLU activations
    and a softmax output named 'sm'."""
    net = mx.symbol.Variable('data')
    net = mx.symbol.FullyConnected(net, name='fc1', num_hidden=128)
    net = mx.symbol.Activation(net, name='relu1', act_type="relu")
    net = mx.symbol.FullyConnected(net, name='fc2', num_hidden=64)
    net = mx.symbol.Activation(net, name='relu2', act_type="relu")
    net = mx.symbol.FullyConnected(net, name='fc3', num_hidden=10)
    return mx.symbol.SoftmaxOutput(net, name='sm')
+
# download ubyte version of mnist and unzip it into data/
def download_data():
    """Ensure the four MNIST ubyte files exist under ./data/.

    Downloads and extracts mnist.zip only if any file is missing.

    Improvements over the original: uses the standard library
    (urllib.request + zipfile + os.makedirs) instead of shelling out to
    wget/unzip/mkdir with os.system (whose failures were silently
    ignored), and does not change the process working directory.
    """
    import zipfile
    import urllib.request

    required = ('train-images-idx3-ubyte', 'train-labels-idx1-ubyte',
                't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte')
    if not os.path.isdir('data/'):
        os.makedirs('data/')
    if not all(os.path.exists(os.path.join('data', name)) for name in required):
        zip_path = os.path.join('data', 'mnist.zip')
        urllib.request.urlretrieve('http://data.mxnet.io/mxnet/data/mnist.zip',
                                   zip_path)
        with zipfile.ZipFile(zip_path) as archive:
            archive.extractall('data')
+
# get data iterators
def get_iters(batch_size):
    """Return (train, val) MNIST iterators over flat 784-dim vectors.

    Both iterators shuffle; the training iterator is seeded for
    reproducibility.
    """
    shared_kwargs = dict(
        data_shape=(784,),
        label_name='sm_label',
        batch_size=batch_size,
        shuffle=True,
        flat=True,
        silent=False)
    train = mx.io.MNISTIter(
        image="data/train-images-idx3-ubyte",
        label="data/train-labels-idx1-ubyte",
        seed=10,
        **shared_kwargs)
    val = mx.io.MNISTIter(
        image="data/t10k-images-idx3-ubyte",
        label="data/t10k-labels-idx1-ubyte",
        **shared_kwargs)
    return (train, val)
+
def test_mlp(args):
    """Train the MLP on MNIST with the 'sparsesgd' optimizer, then delete
    the checkpoint files it produced.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed CLI arguments: ``pruning_switch_epoch`` (comma-separated
        ints, required) plus either the ``*_sparsity`` pair or the
        ``*_threshold`` pair (comma-separated floats); the unused pair is
        passed through to the optimizer as None.

    Fix over the original: the final log message read 'Finish traning...'.
    """
    # get parameters
    prefix = './mlp'
    batch_size = 100
    pruning_switch_epoch = [int(i) for i in args.pruning_switch_epoch.split(',')]
    # the last switch epoch doubles as the total number of training epochs
    num_epoch = pruning_switch_epoch[-1]
    batches_per_epoch = ceil(60000.0 / batch_size)
    # default to the raw (None) values; exactly one pair gets parsed below
    weight_sparsity = args.weight_sparsity
    bias_sparsity = args.bias_sparsity
    weight_threshold = args.weight_threshold
    bias_threshold = args.bias_threshold
    if args.weight_sparsity:
        weight_sparsity = [float(i) for i in args.weight_sparsity.split(',')]
        bias_sparsity = [float(i) for i in args.bias_sparsity.split(',')]
    else:
        weight_threshold = [float(i) for i in args.weight_threshold.split(',')]
        bias_threshold = [float(i) for i in args.bias_threshold.split(',')]

    # get symbols and iterators
    sym = get_symbol()
    download_data()
    (train, val) = get_iters(batch_size)

    # fit model on two CPU contexts
    model = mx.mod.Module(
        sym,
        context=[mx.cpu(i) for i in range(2)],
        data_names=['data'],
        label_names=['sm_label'])
    optimizer_params = {
        'learning_rate'             : 0.1,
        'wd'                        : 0.004,
        'momentum'                  : 0.9,
        'pruning_switch_epoch'      : pruning_switch_epoch,
        'batches_per_epoch'         : batches_per_epoch,
        'weight_sparsity'           : weight_sparsity,
        'bias_sparsity'             : bias_sparsity,
        'weight_threshold'          : weight_threshold,
        'bias_threshold'            : bias_threshold}
    logging.info('Start training...')
    model.fit(train,
        eval_data=val,
        eval_metric='acc',
        epoch_end_callback=mx.callback.do_checkpoint(prefix),
        num_epoch=num_epoch,
        optimizer='sparsesgd',
        optimizer_params=optimizer_params)
    logging.info('Finish training...')

    # remove the checkpoint files written by do_checkpoint (one per epoch)
    for i in range(num_epoch):
        os.remove('%s-%04d.params' % (prefix, i + 1))
    os.remove('%s-symbol.json' % prefix)
+
+
if __name__ == "__main__":
    # print logging by default
    logging.basicConfig(level=logging.DEBUG)

    # all options are comma-separated strings parsed inside test_mlp
    parser = argparse.ArgumentParser(description="sparse training")
    for option in ('--pruning_switch_epoch', '--weight_sparsity',
                   '--bias_sparsity', '--weight_threshold',
                   '--bias_threshold'):
        parser.add_argument(option, type=str, default=None)

    test_mlp(parser.parse_args())
diff --git a/example/dsd/sparse_sgd.py b/example/dsd/sparse_sgd.py
new file mode 100644
index 0000000..f11a239
--- /dev/null
+++ b/example/dsd/sparse_sgd.py
@@ -0,0 +1,170 @@
+from mxnet.ndarray import NDArray, topk, abs as NDabs
+from mxnet.optimizer import SGD, register
+import logging
+
+log = 'Sparsity Update:\t'
+
@register
class SparseSGD(SGD):
    """The SGD optimizer with weight pruning.

    This class implements the optimizer described in the paper *DSD:
    Dense-Sparse-Dense Training for Deep Neural Networks*, available at
    https://arxiv.org/pdf/1607.04381.pdf

    The optimizer updates the weights the same way as done in SGD, but does the
    following preprocessing::

        if threshold given, all weights below the threshold in absolute value are pruned,
            mask    =   abs(weight) >= threshold
        if sparsity level given, the smallest (sparsity)% weights in absolute value are pruned
        (or the largest (100-sparsity)% weights in absolute value are used)
            mask    =   topk(abs(weight), ret_typ='mask', k=weight.size*(100-sparsity)/100)

        => mask[i,j]    =   {0 if weight[i,j] is pruned, 1 otherwise} (for a matrix representation)

        weight  =   weight  *   mask
        grad    =   grad    *   mask
        state   =   state   *   mask

    This optimizer accepts the following parameters in addition to those accepted
    by :class:`.SGD`.

    Parameters
    ----------
    pruning_switch_epoch : list of ints, optional
        The epochs at which there is a change in sparsity level (should be in
        ascending order).

    weight_sparsity : list of floats, optional
        The sparsity on the weights required on each iteration of sparse training.

    bias_sparsity : list of floats, optional
        The sparsity on the biases required on each iteration of sparse training.

    weight_threshold : list of floats, optional
        The absolute value threshold on the weights required on each iteration
        of sparse training.

    bias_threshold : list of floats, optional
        The absolute value threshold on the biases required on each iteration
        of sparse training.

    batches_per_epoch : int, optional
        The number of batches in each epoch.
        (The ceiling integer value of number_of_examples / batch_size)
    """
    def __init__(self, pruning_switch_epoch, batches_per_epoch,
                 weight_sparsity=None, bias_sparsity=None,
                 weight_threshold=None, bias_threshold=None, **kwargs):
        super(SparseSGD, self).__init__(**kwargs)

        # masks[i] is the 0/1 pruning mask for the parameter with index i;
        # the list is grown lazily during the first epoch (see update_masks)
        self.masks = []
        # True once every parameter's mask matches the current
        # sparsity/threshold setting
        self.masks_updated = False
        # last epoch observed by update_masks (0 = before training starts)
        self.epoch = 0
        self.pruning_switch_epoch = pruning_switch_epoch
        self.batches_per_epoch = batches_per_epoch

        # get weight and bias sparsity percentages
        self.weight_sparsity = weight_sparsity
        self.bias_sparsity = bias_sparsity
        if weight_sparsity is not None:
            assert len(weight_sparsity) == len(bias_sparsity), \
                'weight_sparsity and bias_sparsity should have same length'
            assert len(weight_sparsity) == len(pruning_switch_epoch), \
                'pruning_switch_epoch and weight_sparsity should have same length'

        # get weight and bias sparsity thresholds
        self.weight_threshold = weight_threshold
        self.bias_threshold = bias_threshold
        if weight_threshold is not None:
            assert len(weight_threshold) == len(bias_threshold), \
                'weight_threshold and bias_threshold should have same length'
            assert len(weight_threshold) == len(pruning_switch_epoch), \
                'pruning_switch_epoch and weight_sparsity_threshold should have same length'

        # either percentages or thresholds must be given
        assert weight_sparsity is not None or weight_threshold is not None,\
            'weight_sparsity or weight_sparsity_threshold should be given'

    def update_masks(self, index, weight):
        """Updates the masks for sparse training.

        Parameters
        ----------
        index : int
            The index for weight.
        weight : NDArray
            The weight matrix.

        Returns
        -------
        boolean
            If the masks were changed
        """
        # determine number of updates without actually updating the count
        # (SGD.update, invoked after this method, performs the real increment)
        if index not in self._index_update_count:
            num_update = self.begin_num_update
        else:
            num_update = self._index_update_count[index]
        num_update += 1
        num_update = max(num_update, self.num_update)

        # calculate epoch (1-based: the first batches_per_epoch updates
        # belong to epoch 1)
        epoch = int((num_update - 1) / self.batches_per_epoch) + 1

        # determine if masks need to be updated, and get corresponding parameters
        # NOTE(review): this assumes parameter index 0 is processed first in
        # every batch, so seeing it again means all masks were already
        # refreshed for the current setting -- confirm against the caller
        if index == 0:
            self.masks_updated = True
        if self.epoch != epoch:
            self.epoch = epoch
            # first epoch: force mask creation for every parameter and log
            # the initial sparsity/threshold setting
            if epoch == 1:
                self.masks_updated = False
                if self.weight_sparsity is not None:
                    logging.info(log + 'bias-sparsity={}, weight-sparsity={}'.format(self.bias_sparsity[0], self.weight_sparsity[0]))
                else:
                    logging.info(log + 'bias-threshold={}, weight-threshold={}'.format(self.bias_threshold[0], self.weight_threshold[0]))
            # the setting switches at the start of the epoch *after* the one
            # listed in pruning_switch_epoch; consume the finished setting
            # (these pops mutate the lists passed to __init__)
            if self.pruning_switch_epoch[0] + 1 == epoch:
                self.masks_updated = False
                self.pruning_switch_epoch.pop(0)
                if self.weight_sparsity is not None:
                    self.weight_sparsity.pop(0)
                    self.bias_sparsity.pop(0)
                    logging.info(log + 'bias-sparsity={}, weight-sparsity={}'.format(self.bias_sparsity[0], self.weight_sparsity[0]))
                else:
                    self.weight_threshold.pop(0)
                    self.bias_threshold.pop(0)
                    logging.info(log + 'bias-threshold={}, weight-threshold={}'.format(self.bias_threshold[0], self.weight_threshold[0]))

        # update masks if needed
        if not self.masks_updated:
            # initialize masks: during epoch 1 each new index appends a slot,
            # which is filled just below
            if epoch == 1:
                self.masks.append(None)
            # if percentages are given
            if self.weight_sparsity is not None:
                # 1-D parameters are treated as biases, all others as weights
                if len(weight.shape) == 1:
                    sparsity = self.bias_sparsity[0]
                else:
                    sparsity = self.weight_sparsity[0]
                # keep the (100 - sparsity)% entries largest in absolute value
                number_unpruned = int((100.0 - sparsity) * weight.size / 100.0)
                self.masks[index] = topk(NDabs(weight), axis=None, ret_typ='mask',
                                         k=number_unpruned)
            # if thresholds are given
            else:
                if len(weight.shape) == 1:
                    threshold = self.bias_threshold[0]
                else:
                    threshold = self.weight_threshold[0]
                # prune every entry strictly below the threshold in magnitude
                self.masks[index] = NDabs(weight) >= threshold

        return not self.masks_updated

    def update(self, index, weight, grad, state):
        """Applies the pruning mask to weight, grad and state (in place),
        then delegates the actual update to SGD.

        The weight is re-masked only when the mask just changed; grad and
        state are masked on every call so pruned entries never move.
        """
        assert(isinstance(weight, NDArray))
        assert(isinstance(grad, NDArray))

        # preprocessing for pruning
        if self.update_masks(index, weight):
            weight[:] = weight * self.masks[index]
        grad[:] = grad * self.masks[index]
        if state is not None:
            state[:] = state * self.masks[index]

        super(SparseSGD, self).update(index, weight, grad, state)

-- 
To stop receiving notification emails like this one, please contact
['"[email protected]" <[email protected]>'].

Reply via email to