eric-haibin-lin opened a new issue #17722: Large Embedding Model Support in 
MXNet
URL: https://github.com/apache/incubator-mxnet/issues/17722
 
 
   # Problem Statement 
   
   For search, ads, or recommender systems, the embeddings are large and pose challenges for the deep learning system in terms of efficient computation, memory, storage, and networking. For some use cases, the embedding table may be up to hundreds of gigabytes, which does not fit into a single node's CPU memory. The proposal below describes one way to enable such models on the existing MXNet system.
   
   # Current State in MXNet
   
   ### sparse storage types
   
   Two sparse storage types are supported in MXNet: compressed sparse row (CSR) and row sparse. A minimal construction sketch for both types follows the list below.
   
   [Row Sparse 
NDArray](https://mxnet.apache.org/api/python/docs/tutorials/packages/ndarray/sparse/row_sparse.html)
   
   * used to represent sparse weight and sparse gradient
   
   [CSR 
NDArray](https://mxnet.apache.org/api/python/docs/tutorials/packages/ndarray/sparse/csr.html)
 
   
   * used to represent sparse data input 
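
   For reference, a minimal construction sketch for both storage types with the NDArray API (the shapes and values below are illustrative):
    ```
    import mxnet as mx

    # CSR: a 2x3 matrix with non-zeros at (0, 1) and (1, 2), given as (data, indices, indptr)
    x = mx.nd.sparse.csr_matrix(([7.0, 8.0], [1, 2], [0, 1, 2]), shape=(2, 3))

    # row_sparse: a 4x2 array where only rows 0 and 2 are non-zero, given as (data, row ids)
    w = mx.nd.sparse.row_sparse_array(([[1.0, 2.0], [3.0, 4.0]], [0, 2]), shape=(4, 2))

    print(x.stype, w.stype)  # csr row_sparse
    ```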
   
   Two typical use cases (a minimal imperative sketch of both follows this list):
    
    1. x is a csr_matrix, w is a row_sparse array. Sparse matrix-matrix multiplication is performed:
        1. forward computation y = dot(x, w), which generates a dense output
        2. backward computation dw = dot(x.T, dy), which generates a row_sparse gradient
            1. the non-zero row ids in dw correspond to the column ids in x
    2. x is a dense array, w is a row_sparse array for embedding.
        1. forward computation y = embedding(x, w), which generates a dense output
        2. backward computation dw is a row_sparse gradient.
            1. the non-zero row ids in dw correspond to the values in x
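
   A minimal imperative sketch of both use cases (shapes and values are illustrative; the row_sparse gradient storage is requested via attach_grad):
    ```
    import mxnet as mx
    from mxnet import autograd

    # Case 1: csr input x, row_sparse weight w
    x = mx.nd.sparse.csr_matrix(([1.0, 1.0], [0, 3], [0, 1, 2]), shape=(2, 4))
    w = mx.nd.random.uniform(shape=(4, 5)).tostype('row_sparse')
    w.attach_grad(stype='row_sparse')
    with autograd.record():
        y = mx.nd.sparse.dot(x, w)       # dense output
    y.backward()
    print(w.grad.stype)                  # row_sparse; non-zero rows = column ids in x

    # Case 2: dense indices, embedding weight with a sparse gradient
    data = mx.nd.array([0, 2])
    weight = mx.nd.random.uniform(shape=(4, 5))
    weight.attach_grad(stype='row_sparse')
    with autograd.record():
        y = mx.nd.Embedding(data, weight, input_dim=4, output_dim=5, sparse_grad=True)
    y.backward()
    print(weight.grad.stype)             # row_sparse; non-zero rows = values in data
    ```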
   
   ### sparse support in KVStore
   
   The KVStore in MXNet supports both dense and row_sparse ndarrays; a short sketch of the row_sparse path follows the list below.
   
   * kv.init(key, value)
       * initialize the “key” with “value” in KVStore. supports both dense and 
row_sparse. 
   * kv.push(key, [values])
       * accumulate the list of values and push to kvstore. 
       * when the updater is set on kvstore, trigger SGD update on server
       * supports both dense and row_sparse
   * kv.pull(key, out=output) 
       * pull latest value of “key” from server to the output array
       * only supports dense arrays
   * kv.row_sparse_pull(key, row_ids, out=output)
       * pull latest value of “key” on a subset of rows specified by “row_ids” 
from server
       * only supports row_sparse arrays
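
   A minimal sketch of the row_sparse path against a local kvstore (the key name and shapes are illustrative):
    ```
    import mxnet as mx

    kv = mx.kv.create('local')
    shape = (10, 4)

    # init with a row_sparse weight
    kv.init('embedding', mx.nd.random.uniform(shape=shape).tostype('row_sparse'))

    # push a row_sparse gradient whose non-zero rows are 1 and 3
    grad = mx.nd.sparse.row_sparse_array((mx.nd.ones((2, 4)), [1, 3]), shape=shape)
    kv.push('embedding', grad)

    # pull back only the rows needed for the next mini-batch
    out = mx.nd.sparse.zeros('row_sparse', shape)
    kv.row_sparse_pull('embedding', row_ids=mx.nd.array([1, 3]), out=out)
    ```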
   
   ### distributed training with gluon 
   ```
   # network
   net = Net()
   net.initialize(mx.init.Uniform(), ctx=contexts)
   # trainer 
    kvstore = mx.kv.create('dist')
    trainer = Trainer(net.collect_params(), 'sgd', kvstore=kvstore)
   # training loop
   for batch in batches:
     xs = split_and_load(batch, ctx=contexts)
     losses = []
     with autograd.record():
       for x in xs:
         loss = net(x)
         losses.append(loss)
     autograd.backward(losses)
      trainer.step(batch.shape[0]) # performs kvstore.push, kvstore.pull
    # serialization after training is done
   net.save_parameters() 
   trainer.save_states() # for resuming optimizer states
   ```
   ### problem with the current design 
   
   In the current design, MXNet assumes the model fits on a single GPU, and reports an OOM error if the model is too large. A few components in the current implementation lead to the OOM error:
   
    * initialization
        * the model parameters are first initialized on the CPU
        * they are then copied to the GPUs and sent to the server for initialization
    * kvstore reduce buffer
        * the kvstore reduce buffer is allocated in a round-robin way across keys
        * the buffer for each key is assumed to fit on a GPU
    * model serialization
        * model parameters are always saved on the worker side, assuming the worker has access to all parameters
   
   # Proposed Solution
   
   <img width="641" alt="Screen Shot 2020-02-23 at 4 59 16 PM" 
src="https://user-images.githubusercontent.com/5545640/75595758-868dc200-5a42-11ea-8ce6-a9fd0691ece8.png";>
   
   ## end user experience
   ```
   # network
    kvstore = mx.kv.create('dist')
    contexts = [mx.gpu(0), mx.gpu(1)]
    embedding = gluon.sparse.Embedding(sparse_weight=True, sparse_grad=True)
    embedding.initialize(mx.init.Uniform(), ctx=[mx.cpu()], init_from_kv=kvstore)
    net = Net()
    net.initialize(mx.init.Uniform(), ctx=contexts)
    # trainer
    trainer = Trainer(net.collect_params() + embedding.collect_params(),
                      'sgd', kvstore=kvstore)
   # training loop 
   for batch in batches:
     # pull back latest embedding weights for mini-batch
     kvstore.row_sparse_pull('embedding', row_ids=batch.indices,
                             out=embedding.weight.data())
     xs = split_and_load(batch, ctx=contexts)
     losses = []
     with autograd.record():
       for x in xs:
         embed = embedding(x).copyto(mx.gpu())
         loss = net(embed)
         losses.append(loss)
     autograd.backward(losses)
      trainer.step(batch.shape[0]) # performs kvstore.push, kvstore.pull
   
    # serialization after training is done
   net.save_parameters()
   trainer.save_states() # for optimizer states and remote embeddings
   ```
   
   ### Functionality: Lazy Initialization for Gluon Sparse Parameters
   
   By default, net.initialize(initializer, contexts) initializes parameters with random weights and copies them to the target contexts. For our use case, we add an init_from_kv flag that registers the initializer on the KVStore server instead of running the initializer locally.
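
   A rough sketch of the intended worker-side behavior; the init_from_kv flag, the helper below, and the initializer argument to kvstore.init are all part of this proposal rather than existing API:
    ```
    import mxnet as mx

    # hypothetical helper illustrating the proposed init_from_kv semantics
    def initialize_param(param, initializer, ctx, init_from_kv=None):
        if init_from_kv is None:
            # existing path: random init locally, then copy to the target contexts
            param.initialize(init=initializer, ctx=ctx)
        else:
            # proposed path: never materialize the full weight on the worker;
            # register the initializer so the server initializes rows lazily
            placeholder = mx.nd.sparse.zeros('row_sparse', param.shape)
            init_from_kv.init(param.name, placeholder, initializer=initializer)
    ```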
   
   ### Functionality: Initializer on KVStore Servers
   
   We add an initializer argument to the kvstore API: kvstore.init(key, array, initializer=mx.init.Uniform()), so that the kvstore is registered with the initializer (similar to how optimizers are registered). On pull requests, the kvstore server first checks whether the tensor has been initialized; if not, it runs the initializer before returning the result.
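
   A rough sketch of the corresponding server-side handling; the handler below is hypothetical and only illustrates the lazy-initialization check described above:
    ```
    import mxnet as mx

    # hypothetical server-side entry: allocate and initialize on the first pull
    class ServerEntry:
        def __init__(self, shape, initializer):
            self.shape = shape
            self.initializer = initializer
            self.value = None            # nothing is materialized yet

        def pull(self, row_ids):
            if self.value is None:
                self.value = mx.nd.zeros(self.shape)
                self.initializer(mx.init.InitDesc('embedding'), self.value)
            return mx.nd.take(self.value, row_ids)
    ```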
   
   ### Functionality: Model Checkpoints on KVStore Servers
   
   The parameter server requires the functionality to save parameters on the server side, plus a utility function to merge all keys if the parameters are sharded.
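
   A rough sketch of the merge utility, assuming each server checkpoints its shard as a row_sparse array; the file layout and helper name are hypothetical, and very large tables would need a streaming variant:
    ```
    import mxnet as mx

    def merge_shards(shard_files, full_shape):
        # merge per-server row_sparse shards back into one embedding table;
        # assumes each file holds a single row_sparse array saved by one server
        merged = mx.nd.zeros(full_shape)
        for path in shard_files:
            shard = mx.nd.load(path)[0]
            merged += shard.tostype('default')  # shards cover disjoint rows
        return merged
    ```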
   
   ### Optimization: Data and Parameter Pre-fetching
   
   We can rewrite the training loop to perform pre-fetching:
   ```
    next_batch = next(batch_iter)
   kvstore.row_sparse_pull('embedding', row_ids=next_batch.indices,
                           out=embedding.weight.data())
   while True:
     batch = next_batch
     xs = split_and_load(batch, ctx=contexts)
     losses = []
     with autograd.record():
       for x in xs:
         embed = embedding(x).copyto(mx.gpu())
         loss = net(embed)
         losses.append(loss)
     autograd.backward(losses)
      trainer.step(batch.shape[0]) # performs kvstore.push, kvstore.pull
      try:
        next_batch = next(batch_iter)
        kvstore.row_sparse_pull('embedding', row_ids=next_batch.indices,
                                out=embedding.weight.data())
      except StopIteration:
        break # end of an epoch
   ```
   ### Optimization: Parameter Slicing and Scheduling for Large Embeddings 
   
   Since the embedding layer is extremely large, in kvstore we generate one key per row of the embedding tensor. We shard the key space across the available parameter servers. The slicing and sharding strategy can follow the p3 kvstore implementation in MXNet.
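
   A minimal sketch of a per-row key-to-server mapping; the modulo policy below is only illustrative, and the real strategy can reuse the p3 kvstore slicing:
    ```
    def server_for_row(row_id, num_servers):
        # illustrative policy: assign each embedding row (= one kvstore key)
        # to a parameter server by simple modulo hashing
        return row_id % num_servers

    # e.g. with 4 servers, rows 0..7 land on servers 0, 1, 2, 3, 0, 1, 2, 3
    assignment = {row: server_for_row(row, 4) for row in range(8)}
    ```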
   
   ### Optimization: CPU performance optimization 
   
   Embedding lookup on CPU may be accelerated with a fast hash-table implementation. The parameter update on the server may be accelerated with a bfloat16 implementation. The embedding gradient can also be cast to half precision to reduce communication traffic.
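
   For the half-precision communication idea, a minimal sketch of casting only the non-zero rows of a row_sparse gradient before pushing; the grad and kvstore names are assumed from the examples above, and wiring this into the Trainer is left as an open design choice:
    ```
    import mxnet as mx

    # grad: a row_sparse gradient; kvstore: the distributed kvstore from above
    grad_fp16 = mx.nd.sparse.row_sparse_array(
        (grad.data.astype('float16'), grad.indices), shape=grad.shape)
    # the server would cast back to float32 before applying the update
    kvstore.push('embedding', grad_fp16)
    ```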
   
   # Appendix
   
   ### distributed training with module
   
   Below is a sketch of a network with a sparse weight and a sparse gradient, using the Module API.
   ```
    mod = mx.mod.Module(sparse_symbol)
    mod.bind()
    mod.init_params()       # initialize params on CPU, then copy params to GPUs
    mod.init_optimizer()
    mod.prepare(batch, sparse_row_id_fn)
    # kv.row_sparse_pull(sparse_layer, row_id): pulls the subset of rows to workers
    mod.forward_backward()
    mod.update()
    mod.prepare(batch, all_row_id_fn)
    # kv.row_sparse_pull(sparse_layer, all_rows): pulls all latest rows to workers
    mod.save_checkpoint()   # serializes optimizer state and model parameters
   ```
