piiswrong closed pull request #11197: Gluon sparse block and sparse embedding. URL: https://github.com/apache/incubator-mxnet/pull/11197
This is a PR merged from a forked repository. Because GitHub hides the original diff of a merged pull request from a fork, the diff is reproduced below for the sake of provenance: diff --git a/docs/api/python/gluon/contrib.md b/docs/api/python/gluon/contrib.md index bc3089fa878..877a294d9a1 100644 --- a/docs/api/python/gluon/contrib.md +++ b/docs/api/python/gluon/contrib.md @@ -35,6 +35,7 @@ In the rest of this document, we list routines provided by the `gluon.contrib` p Concurrent HybridConcurrent Identity + SparseEmbedding ``` ### Recurrent neural network @@ -55,6 +56,7 @@ In the rest of this document, we list routines provided by the `gluon.contrib` p Conv1DGRUCell Conv2DGRUCell Conv3DGRUCell + LSTMPCell ``` ### Data diff --git a/python/mxnet/gluon/contrib/nn/basic_layers.py b/python/mxnet/gluon/contrib/nn/basic_layers.py index eccdf18c1bb..1edef1476ee 100644 --- a/python/mxnet/gluon/contrib/nn/basic_layers.py +++ b/python/mxnet/gluon/contrib/nn/basic_layers.py @@ -18,10 +18,10 @@ # coding: utf-8 # pylint: disable= arguments-differ """Custom neural network layers in model_zoo.""" -__all__ = ['Concurrent', 'HybridConcurrent', 'Identity'] +__all__ = ['Concurrent', 'HybridConcurrent', 'Identity', 'SparseEmbedding'] from .... import nd -from ...block import HybridBlock +from ...block import HybridBlock, Block from ...nn import Sequential, HybridSequential class Concurrent(Sequential): @@ -110,3 +110,44 @@ def __init__(self, prefix=None, params=None): def hybrid_forward(self, F, x): return x + +class SparseEmbedding(Block): + r"""Turns non-negative integers (indexes/tokens) into dense vectors + of fixed size. eg. [4, 20] -> [[0.25, 0.1], [0.6, -0.2]] + + This SparseBlock is designed for distributed training with extremely large + input dimension. Both weight and gradient w.r.t. weight are `RowSparseNDArray`. 
+ + Parameters + ---------- + input_dim : int + Size of the vocabulary, i.e. maximum integer index + 1. + output_dim : int + Dimension of the dense embedding. + dtype : str or np.dtype, default 'float32' + Data type of output embeddings. + weight_initializer : Initializer + Initializer for the `embeddings` matrix. + + Inputs: + - **data**: (N-1)-D tensor with shape: `(x1, x2, ..., xN-1)`. + Output: + - **out**: N-D tensor with shape: `(x1, x2, ..., xN-1, output_dim)`. + """ + def __init__(self, input_dim, output_dim, dtype='float32', + weight_initializer=None, **kwargs): + super(SparseEmbedding, self).__init__(**kwargs) + self._kwargs = {'input_dim': input_dim, 'output_dim': output_dim, + 'dtype': dtype, 'sparse_grad': True} + self.weight = self.params.get('weight', shape=(input_dim, output_dim), + init=weight_initializer, dtype=dtype, + grad_stype='row_sparse', stype='row_sparse') + + def forward(self, x): + weight = self.weight.row_sparse_data(x) + return nd.Embedding(x, weight, name='fwd', **self._kwargs) + + def __repr__(self): + s = '{block_name}({input_dim} -> {output_dim}, {dtype})' + return s.format(block_name=self.__class__.__name__, + **self._kwargs) diff --git a/tests/python/unittest/test_gluon_contrib.py b/tests/python/unittest/test_gluon_contrib.py index 729ec8407f2..264ff1f5e53 100644 --- a/tests/python/unittest/test_gluon_contrib.py +++ b/tests/python/unittest/test_gluon_contrib.py @@ -19,7 +19,7 @@ import mxnet as mx from mxnet.gluon import contrib from mxnet.gluon import nn -from mxnet.gluon.contrib.nn import Concurrent, HybridConcurrent, Identity +from mxnet.gluon.contrib.nn import Concurrent, HybridConcurrent, Identity, SparseEmbedding from mxnet.test_utils import almost_equal from common import setup_module, with_seed import numpy as np @@ -185,13 +185,25 @@ def test_concurrent(): x.wait_to_read() x2.wait_to_read() - +@with_seed() def test_identity(): model = Identity() x = mx.nd.random.uniform(shape=(128, 33, 64)) 
mx.test_utils.assert_almost_equal(model(x).asnumpy(), x.asnumpy()) +@with_seed() +def test_sparse_embedding(): + layer = SparseEmbedding(10, 100) + layer.initialize() + trainer = mx.gluon.Trainer(layer.collect_params(), 'sgd') + x = mx.nd.array([3,4,2,0,1]) + with mx.autograd.record(): + y = layer(x) + y.backward() + assert (layer.weight.grad().asnumpy()[:5] == 1).all() + assert (layer.weight.grad().asnumpy()[5:] == 0).all() + def test_datasets(): wikitext2_train = contrib.data.text.WikiText2(root='data/wikitext-2', segment='train') wikitext2_val = contrib.data.text.WikiText2(root='data/wikitext-2', segment='validation', ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services