eric-haibin-lin opened a new issue #17722: Large Embedding Model Support in MXNet
URL: https://github.com/apache/incubator-mxnet/issues/17722

# Problem Statement

For search, ads, and recommender systems, the embeddings are large and pose challenges for the deep learning system in terms of efficient computation, memory, storage, and networking. For some use cases, the embedding table may be up to hundreds of gigabytes, which does not fit into a single node's CPU memory. The proposal below describes one way to enable such models on the existing MXNet system.

# Current State in MXNet

### sparse storage types

Two sparse storage types are supported in MXNet: compressed sparse row (CSR) and row sparse.

* [Row Sparse NDArray](https://mxnet.apache.org/api/python/docs/tutorials/packages/ndarray/sparse/row_sparse.html)
  * used to represent sparse weights and sparse gradients
* [CSR NDArray](https://mxnet.apache.org/api/python/docs/tutorials/packages/ndarray/sparse/csr.html)
  * used to represent sparse data inputs

Two typical use cases:

1. x is a csr_matrix, w is a row_sparse array. Sparse matrix-matrix multiplication is performed:
   1. forward computation y = dot(x, w), which generates a dense output
   2. backward computation dw = dot(x.T, dy), which generates a row_sparse gradient
      1. the non-zero row ids in dw correspond to column ids in x
2. x is a dense array, w is a row_sparse array for embedding:
   1. forward computation y = embedding(x, w), which generates a dense output
   2. backward computation dw is a row_sparse gradient
      1. the non-zero row ids in dw correspond to the values in x

### sparse support in KVStore

The KVStore in MXNet supports both dense and row_sparse ndarrays.

* kv.init(key, value)
  * initialize "key" with "value" in KVStore; supports both dense and row_sparse
* kv.push(key, [values])
  * accumulate the list of values and push to kvstore
  * when the updater is set on kvstore, trigger an SGD update on the server
  * supports both dense and row_sparse
* kv.pull(key, out=output)
  * pull the latest value of "key" from the server into the output array
  * only supports dense arrays
* kv.row_sparse_pull(key, row_ids, out=output)
  * pull the latest value of "key" on the subset of rows specified by "row_ids" from the server
  * only supports row_sparse arrays

### distributed training with gluon

```
# network
net = Net()
net.initialize(mx.init.Uniform(), ctx=contexts)

# trainer
kvstore = mx.kv.create('dist')
trainer = Trainer(net.collect_params(), kv=kvstore)

# training loop
for batch in batches:
    xs = split_and_load(batch, ctx=contexts)
    losses = []
    with autograd.record():
        for x in xs:
            loss = net(x)
            losses.append(loss)
    autograd.backward(losses)
    trainer.step()  # performs kvstore.push, kvstore.pull

# serialization after training is done
net.save_parameters()
trainer.save_states()  # for resuming optimizer states
```

### problem with the current design

In the current design, MXNet assumes the model fits on a single GPU, and reports an OOM error if the model is too large. A few components in the current implementation lead to the OOM error:

* initialization
  * the model parameters are first initialized on the CPU
  * then they are copied to GPUs and sent to the server for initialization
* kvstore reduce buffer
  * the kvstore reduce buffer is allocated in a round-robin way across keys
  * the buffer size for each key is assumed to fit on GPUs
* model serialization
  * model parameters are always saved on the worker side, which assumes the worker has access to all parameters

# Proposed Solution

<img width="641" alt="Screen Shot 2020-02-23 at 4 59 16 PM" src="https://user-images.githubusercontent.com/5545640/75595758-868dc200-5a42-11ea-8ce6-a9fd0691ece8.png">

## end user experience

```
# network
kvstore = mx.kv.create('dist')
embedding = gluon.sparse.Embedding(sparse_weight=True, sparse_grad=True)
embedding.initialize(mx.init.Uniform(), ctx=[mx.cpu()], init_from_kv=kvstore)
net = Net()
net.initialize(mx.init.Uniform(), ctx=[mx.gpu(0), mx.gpu(1)])

# trainer
trainer = Trainer(net.collect_params() + embedding.collect_params(), kv=kvstore)

# training loop
for batch in batches:
    # pull back the latest embedding weights for the mini-batch
    kvstore.row_sparse_pull('embedding', row_ids=batch.indices, out=embedding.weight.data())
    xs = split_and_load(batch, ctx=contexts)
    losses = []
    with autograd.record():
        for x in xs:
            embed = embedding(x).copyto(mx.gpu())
            loss = net(embed)
            losses.append(loss)
    autograd.backward(losses)
    trainer.step()  # performs kvstore.push, kvstore.pull

# serialization after training is done
net.save_parameters()
trainer.save_states()  # for optimizer states and remote embeddings
```

### Functionality: Lazy Initialization for Gluon Sparse Parameters

By default, net.initialize(initializer, contexts) initializes parameters with random weights and copies the weights to the target contexts. For our use case, we add an init_from_kv flag that sets the initializer on the KVStore server instead of running the initializer locally.

### Functionality: Initializer on KVStore Servers

We add an initializer argument to the kvstore API, kvstore.init(key, array, initializer=mx.init.Uniform()), so that the kvstore is registered with the initializer (similar to how optimizers are registered).
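The registration-then-lazy-materialization behavior could be sketched in plain Python. This is an illustrative toy, not MXNet code: names like `ToyKVServer` are made up for the sketch, values are plain lists rather than NDArray shards, and the real server would also handle optimizer updates and sharding.

```python
import random

class ToyKVServer:
    """Toy sketch of a kvstore server that defers initialization
    until the first pull. Illustrative only; not actual MXNet code."""

    def __init__(self):
        self.store = {}         # key -> materialized value (list of floats)
        self.initializers = {}  # key -> (length, init_fn)

    def init(self, key, length, initializer):
        # Register the initializer instead of eagerly materializing the
        # value, mirroring how optimizers are registered on the server.
        self.initializers[key] = (length, initializer)

    def pull(self, key):
        # On pull, run the registered initializer if the tensor has not
        # been materialized yet, then return the stored value.
        if key not in self.store:
            length, init_fn = self.initializers[key]
            self.store[key] = [init_fn() for _ in range(length)]
        return self.store[key]

server = ToyKVServer()
server.init('embedding_row_0', length=4,
            initializer=lambda: random.uniform(-1, 1))
row = server.pull('embedding_row_0')  # initializer runs here, on first pull
```

This keeps workers from ever holding the full table: a worker only needs the rows it pulls, while creation of each row happens on the server that owns it.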
On pull requests, the kvstore server first checks whether the tensor is initialized; if not, it runs the initializer before returning the result.

### Functionality: Model Checkpoints on KVStore Servers

The parameter server requires the functionality to save parameters on the server side, plus a utility function to merge all keys if the parameters are sharded.

### Optimization: Data and Parameter Pre-fetching

We can rewrite the training loop to perform pre-fetching:

```
next_batch = batch_iter.next()
kvstore.row_sparse_pull('embedding', row_ids=next_batch.indices, out=embedding.weight.data())
while True:
    batch = next_batch
    xs = split_and_load(batch, ctx=contexts)
    losses = []
    with autograd.record():
        for x in xs:
            embed = embedding(x).copyto(mx.gpu())
            loss = net(embed)
            losses.append(loss)
    autograd.backward(losses)
    trainer.step()  # performs kvstore.push, kvstore.pull
    try:
        next_batch = batch_iter.next()
        kvstore.row_sparse_pull('embedding', row_ids=next_batch.indices, out=embedding.weight.data())
    except StopIteration:
        break  # end of an epoch
```

### Optimization: Parameter Slicing and Scheduling for Large Embeddings

Since the embedding layer is extremely large, in kvstore we generate one key per row in the embedding tensor, and shard the key space across the available parameter servers. The slicing and sharding strategy can follow the p3 kvstore implementation in MXNet.

### Optimization: CPU performance optimization

Embedding lookup on CPU may be accelerated with a fast hash table implementation. The parameter update on the server may be accelerated with a bfloat16 implementation. The embedding gradient can also be cast to half precision to reduce communication traffic.

# Appendix: distributed training with module

Below is a sketch of a network with sparse weight and sparse gradient, using the module API.
```
mod = mx.mod.Module(sparse_symbol)
mod.bind()
mod.init_params()
    # initialize params on CPU
    # copy params to GPUs
mod.init_optimizer()
mod.prepare(batch, sparse_row_id_fn)
    # kv.row_sparse_pull(sparse_layer, row_id): pulls the subset of rows to workers
mod.forward_backward()
mod.update()
mod.prepare(batch, all_row_id_fn)
    # kv.row_sparse_pull(sparse_layer, all_rows): pulls all latest rows to workers
mod.save_checkpoint()
    # serialize optimizer state
    # serialize model parameters
```
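The one-key-per-row sharding from the "Parameter Slicing and Scheduling" section can be illustrated with a small pure-Python sketch. The modulo (round-robin) placement below is an assumption for illustration only; the actual strategy would follow the p3 kvstore implementation.

```python
# Illustrative sketch: one kvstore key per embedding row, with the key
# space sharded across parameter servers. Round-robin (modulo) placement
# is an assumption for illustration, not the p3 kvstore's actual policy.

NUM_SERVERS = 4

def server_for_row(row_id, num_servers=NUM_SERVERS):
    # The key for row i is the row id itself; place keys round-robin.
    return row_id % num_servers

def group_rows_by_server(row_ids, num_servers=NUM_SERVERS):
    # Given the rows touched by a mini-batch, work out which (deduplicated,
    # sorted) rows must be pulled from which server, so the worker can issue
    # one row_sparse_pull per server instead of one per row.
    groups = {}
    for r in sorted(set(row_ids)):
        groups.setdefault(server_for_row(r, num_servers), []).append(r)
    return groups

batch_rows = [3, 7, 3, 0, 5]          # row ids referenced by a mini-batch
requests = group_rows_by_server(batch_rows)
# requests maps server id -> rows to pull from that server
```

Grouping per server keeps the number of network round trips proportional to the number of servers rather than the number of rows in the batch.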
