This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 9a8c164  [MXNET-120] Float16 support for distributed training (#10183)
9a8c164 is described below

commit 9a8c16427f0a9842e0252f8f461e90d5b0cc3dd2
Author: Rahul Huilgol <rahulhuil...@gmail.com>
AuthorDate: Wed Apr 11 10:20:56 2018 -0700

    [MXNET-120] Float16 support for distributed training (#10183)
    
    * send as char
    
    * fix bug on pull response, and rowsparse on worker side
    
    * three modes
    
    * default to mode 0 and add support for row sparse
    
    * refactor sparse
    
    * rowsparse numbytes fixes
    
    * WIP tests
    
    * update test sync
    
    * remove prints
    
    * refactoring
    
    * Revert "refactoring"
    
    This reverts commit 05ffa1bf254057ec70ca6ec1a1deb3b072c31538.
    
    * undo refactoring to keep PR simple
    
    * add wait to stored in pull default
    
    * lint fixes
    
    * undo static cast for recvblob
    
    * lint fixes
    
    * mode 1 changes
    
    * sparse bug fix dtype
    
    * mshadow default
    
    * remove unused var
    
    * remove debug statements
    
    * clearer variables, reduced multiplication, const vars
    
    * add const for more vars, comments
    
    * comment syntax, code watcher, test default val
    
    * remove unnecessary print in test
    
    * trigger ci
    
    * multi precision mode (debugging race condition)
    
    * working rsp pushes
    
    * finish multiprecision for row sparse
    
    * rename num-bytes
    
    * fix bug due to rename of numbytes, and remove debug logs
    
    * address comments
    
    * add integration test
    
    * trigger ci
    
    * integration test
    
    * integration test
    
    * fix path of script
    
    * update mshadow
    
    * disable f16c for amalgamation
    
    * fix amalgamation build
    
    * trigger ci
    
    * disable f16c for jetson
---
 3rdparty/mshadow                   |   2 +-
 CODEOWNERS                         |   7 +-
 CONTRIBUTORS.md                    |   1 +
 Jenkinsfile                        |  11 +
 amalgamation/Makefile              |   9 +
 ci/docker/runtime_functions.sh     |   9 +
 make/crosscompile.jetson.mk        |   2 +-
 python/mxnet/gluon/trainer.py      |   5 +-
 python/mxnet/kvstore.py            |  14 +-
 python/mxnet/model.py              |   5 +-
 python/mxnet/module/module.py      |   7 +-
 src/kvstore/kvstore_dist.h         | 198 ++++++------
 src/kvstore/kvstore_dist_server.h  | 598 ++++++++++++++++++++++++-------------
 tests/nightly/dist_sync_kvstore.py | 335 ++++++++++++---------
 14 files changed, 753 insertions(+), 450 deletions(-)

diff --git a/3rdparty/mshadow b/3rdparty/mshadow
index f5b67f3..0b4cedd 160000
--- a/3rdparty/mshadow
+++ b/3rdparty/mshadow
@@ -1 +1 @@
-Subproject commit f5b67f380cb0588be11e6f440f92f013139380ee
+Subproject commit 0b4cedd7015cc69191f8338a8feaacda90697758
diff --git a/CODEOWNERS b/CODEOWNERS
index 3660e38..1ea9b56 100644
--- a/CODEOWNERS
+++ b/CODEOWNERS
@@ -17,9 +17,12 @@
 /perl-package/    @sergeykolychev
 /python/          @szha
 
+# C++ base
+/src/kvstore/     @rahul003
+
 # CMake
-CMakeLists.txt    @szha
-/cmake/           @szha
+CMakeLists.txt    @szha @rahul003
+/cmake/           @szha @rahul003
 
 # MXNet CI
 /tests/ci_build/    @marcoabreu
diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index a32e33e..2ba0721 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -163,4 +163,5 @@ List of Contributors
 * [David Braude](https://github.com/dabraude/)
 * [Nick Robinson](https://github.com/nickrobinson)
 * [Kan Wu](https://github.com/wkcn)
+* [Rahul Huilgol](https://github.com/rahul003)
 * [Anirudh Subramanian](https://github.com/anirudh2290/)
diff --git a/Jenkinsfile b/Jenkinsfile
index 3892906..8686012 100644
--- a/Jenkinsfile
+++ b/Jenkinsfile
@@ -671,6 +671,17 @@ try {
           }
         }
       }
+    },
+    'dist-kvstore tests GPU': {
+      node('mxnetlinux-gpu') {
+        ws('workspace/it-dist-kvstore') {
+          init_git()
+          unpack_lib('gpu')
+          timeout(time: max_time, unit: 'MINUTES') {
+            sh "ci/build.py --nvidiadocker --platform ubuntu_gpu 
/work/runtime_functions.sh integrationtest_ubuntu_gpu_dist_kvstore"
+          }
+        }
+      }
     }
   }
 
diff --git a/amalgamation/Makefile b/amalgamation/Makefile
index 9c45885..f7f3c00 100644
--- a/amalgamation/Makefile
+++ b/amalgamation/Makefile
@@ -51,6 +51,15 @@ endif
 DEFS+=-DMSHADOW_USE_CUDA=0 -DMSHADOW_USE_MKL=0 -DMSHADOW_RABIT_PS=0 
-DMSHADOW_DIST_PS=0 -DDMLC_LOG_STACK_TRACE=0
 DEFS+=-DMSHADOW_FORCE_STREAM -DMXNET_USE_OPENCV=0 -DMXNET_PREDICT_ONLY=1
 CFLAGS=-std=c++11 -Wno-unknown-pragmas -Wall $(DEFS)
+
+# if architecture of the CPU supports F16C instruction set, enable USE_F16C 
for fast fp16 computation on CPU
+ifeq ($(USE_F16C), 1)
+       CFLAGS+=-mf16c
+       DEFS+=-DMSHADOW_USE_F16C=1
+else
+       DEFS+=-DMSHADOW_USE_F16C=0
+endif
+
 ifneq ($(MIN), 1)
        CFLAGS += -I${OPENBLAS_ROOT} -I${OPENBLAS_ROOT}/include
        LDFLAGS+= -L${OPENBLAS_ROOT} -L${OPENBLAS_ROOT}/lib
diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh
index 019c4ff..44de137 100755
--- a/ci/docker/runtime_functions.sh
+++ b/ci/docker/runtime_functions.sh
@@ -498,6 +498,15 @@ integrationtest_ubuntu_gpu_cpp_package() {
     cpp-package/tests/ci_test.sh
 }
 
+integrationtest_ubuntu_gpu_dist_kvstore() {
+    set -ex
+    export PYTHONPATH=./python/
+    export MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
+    cd tests/nightly/
+    ../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py
+    ../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py 
--no-multiprecision
+    ../../tools/launch.py -n 7 --launcher local python 
dist_device_sync_kvstore.py
+}
 
 test_ubuntu_cpu_python2() {
     set -ex
diff --git a/make/crosscompile.jetson.mk b/make/crosscompile.jetson.mk
index 9ca4109..31a1398 100644
--- a/make/crosscompile.jetson.mk
+++ b/make/crosscompile.jetson.mk
@@ -132,7 +132,7 @@ endif
 # Settings for power and arm arch
 #----------------------------
 USE_SSE=0
-
+USE_F16C=0
 #----------------------------
 # distributed computing
 #----------------------------
diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py
index 88b2e88..e730fd7 100644
--- a/python/mxnet/gluon/trainer.py
+++ b/python/mxnet/gluon/trainer.py
@@ -114,12 +114,13 @@ class Trainer(object):
                 kvstore.set_gradient_compression(self._compression_params)
             if 'dist' in kvstore.type:
                 update_on_kvstore = False
+            if update_on_kvstore:
+                kvstore.set_optimizer(self._optimizer)
+            # optimizer preferably needs to be set before init for 
multiprecision
             for i, param in enumerate(self._params):
                 param_arrays = param.list_data()
                 kvstore.init(i, param_arrays[0])
                 kvstore.pull(i, param_arrays, priority=-i)
-            if update_on_kvstore:
-                kvstore.set_optimizer(self._optimizer)
             self._kvstore = kvstore
             self._update_on_kvstore = update_on_kvstore
         else:
diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py
index 5520597..f31dac0 100644
--- a/python/mxnet/kvstore.py
+++ b/python/mxnet/kvstore.py
@@ -83,6 +83,14 @@ def _updater_wrapper(updater):
         updater(key, lhs, rhs)
     return updater_handle
 
+def _get_kvstore_server_command_type(command):
+    command_types = {'kController': 0,
+                     'kSetMultiPrecision': 1,
+                     'kStopServer': 2,
+                     'kSyncMode': 3,
+                     'kSetGradientCompression': 4}
+    assert (command in command_types), "Unknown command type to send to server"
+    return command_types[command]
 
 class KVStore(object):
     """A key-value store for synchronization of values, over multiple 
devices."""
@@ -473,7 +481,11 @@ class KVStore(object):
                 optim_str = py_str(pickle.dumps(optimizer, 0))
             except:
                 raise
-            self._send_command_to_servers(0, optim_str)
+            cmd = _get_kvstore_server_command_type('kController')
+            self._send_command_to_servers(cmd, optim_str)
+            if optimizer.multi_precision:
+                cmd = _get_kvstore_server_command_type('kSetMultiPrecision')
+                self._send_command_to_servers(cmd, '')
         else:
             self._set_updater(opt.get_updater(optimizer))
 
diff --git a/python/mxnet/model.py b/python/mxnet/model.py
index 26e885a..ae7726d 100644
--- a/python/mxnet/model.py
+++ b/python/mxnet/model.py
@@ -253,6 +253,8 @@ def _train_multi_device(symbol, ctx, arg_names, 
param_names, aux_names,
 
     if not update_on_kvstore:
         updater = get_updater(optimizer)
+    else:
+        kvstore.set_optimizer(optimizer)
 
     if kvstore:
         _initialize_kvstore(kvstore=kvstore,
@@ -261,9 +263,6 @@ def _train_multi_device(symbol, ctx, arg_names, 
param_names, aux_names,
                             param_names=executor_manager.param_names,
                             update_on_kvstore=update_on_kvstore)
 
-    if update_on_kvstore:
-        kvstore.set_optimizer(optimizer)
-
     # Now start training
     train_data.reset()
     for epoch in range(begin_epoch, end_epoch):
diff --git a/python/mxnet/module/module.py b/python/mxnet/module/module.py
index 21d9b56..a05c3a3 100644
--- a/python/mxnet/module/module.py
+++ b/python/mxnet/module/module.py
@@ -536,15 +536,16 @@ class Module(BaseModule):
         if kvstore:
             if self._compression_params:
                 kvstore.set_gradient_compression(self._compression_params)
+            if update_on_kvstore:
+                kvstore.set_optimizer(self._optimizer)
             # copy initialized local parameters to kvstore
             _initialize_kvstore(kvstore=kvstore,
                                 param_arrays=self._exec_group.param_arrays,
                                 arg_params=self._arg_params,
                                 param_names=self._param_names,
                                 update_on_kvstore=update_on_kvstore)
-        if update_on_kvstore:
-            kvstore.set_optimizer(self._optimizer)
-        else:
+
+        if not update_on_kvstore:
             self._updater = opt.get_updater(optimizer)
 
         self.optimizer_initialized = True
diff --git a/src/kvstore/kvstore_dist.h b/src/kvstore/kvstore_dist.h
index afba9ac..373081b 100644
--- a/src/kvstore/kvstore_dist.h
+++ b/src/kvstore/kvstore_dist.h
@@ -47,7 +47,7 @@ class KVStoreDist : public KVStoreLocal {
       : KVStoreLocal(use_device_comm), ps_worker_(nullptr), server_(nullptr) {
     if (IsWorkerNode()) {
       int new_customer_id = GetNewCustomerId();
-      ps_worker_ = new ps::KVWorker<real_t>(0, new_customer_id);
+      ps_worker_ = new ps::KVWorker<char>(0, new_customer_id);
       ps::StartAsync(new_customer_id, "mxnet\0");
       if (!ps::Postoffice::Get()->is_recovery()) {
         ps::Postoffice::Get()->Barrier(
@@ -228,17 +228,18 @@ class KVStoreDist : public KVStoreLocal {
           RunContext rctx, Engine::CallbackOnComplete cb) {
         // convert to ps keys
         size_t size = recv_buf.shape().Size();
-
+        const int dtype = recv_buf.dtype();
+        const int num_bytes = mshadow::mshadow_sizeof(dtype);
         PSKV& pskv = (gradient_compression_->get_type() == 
CompressionType::kNone) ?
-                      EncodeDefaultKey(key, size, false) :
-                      EncodeCompressedKey(key, size, false);
-        real_t* data = recv_buf.data().dptr<real_t>();
+                      EncodeDefaultKey(key, size, num_bytes) :
+                      EncodeCompressedKey(key, size, false, num_bytes);
+        char* data = static_cast<char*> (recv_buf.data().dptr_);
         // false means not to delete data when SArray is deleted
-        auto vals = new ps::SArray<real_t>(data, size, false);
+        auto vals = new ps::SArray<char>(data, size * num_bytes, false);
         // issue pull
-        int cmd = (gradient_compression_->get_type() != 
CompressionType::kNone) ?
-                  static_cast<int>(DataHandleType::kCompressedPushPull) :
-                  static_cast<int>(DataHandleType::kDefaultPushPull);
+        RequestType mode = (gradient_compression_->get_type() != 
CompressionType::kNone) ?
+                  RequestType::kCompressedPushPull : 
RequestType::kDefaultPushPull;
+        const int cmd = GetCommandType(mode, dtype);
         CHECK_NOTNULL(ps_worker_)->ZPull(
           pskv.keys, vals, &pskv.lens, cmd, [vals, cb](){ delete vals; cb(); 
});
       };
@@ -329,18 +330,21 @@ class KVStoreDist : public KVStoreLocal {
         }
         CopyFromTo(merged, &comm_buf);
       }
-
+      const int dtype = merged.dtype();
+      const int num_bytes = mshadow::mshadow_sizeof(dtype);
       // push to servers
       if (storage_type == kDefaultStorage) {
         if (gradient_compression_->get_type() == CompressionType::kNone) {
-          PSKV& pskv = EncodeDefaultKey(key, comm_buf.shape().Size(), true);
+          PSKV& pskv = EncodeDefaultKey(key, comm_buf.shape().Size(), 
num_bytes);
           PushDefault(key, comm_buf, pskv, priority);
         } else {
+          CHECK_EQ(dtype, mshadow::kFloat32) << "Gradient compression is only 
supported for "
+                                             << "float32 type of parameters";
           // Note: gradient compression uses `do_merge` as proxy to
           // detect whether the push is initialization of a key or not.
           // is_active is false when push is initialization of key
           bool is_active = do_merge;
-          PSKV &pskv = EncodeCompressedKey(key, comm_buf.shape().Size(), 
is_active);
+          PSKV &pskv = EncodeCompressedKey(key, comm_buf.shape().Size(), 
is_active, num_bytes);
           // Returns push_pskv if active, else pull_pskv
           // we want inactive gc to send uncompressed gradients,
           // but sharded in the same way as later pushes would when gc becomes 
active
@@ -363,25 +367,24 @@ class KVStoreDist : public KVStoreLocal {
   void PushCompressed(int key, const NDArray& comm_buf, const PSKV& pskv, int 
priority) {
     auto &small_buf = compr_buf_[key];
     auto &res_buf = residual_[key];
-    size_t original_size = comm_buf.shape().Size();
+    const size_t original_size = comm_buf.shape().Size();
+    const int dtype = comm_buf.dtype();
 
     // Init the small buffer and residual_ buffer for quantize
     if (small_buf.is_none()) {
-      small_buf = NDArray(TShape{pskv.size}, comm_buf.ctx(), false, 
comm_buf.dtype());
-      res_buf = NDArray(TShape{(int64_t) original_size}, comm_buf.ctx(),
-                        false, comm_buf.dtype());
+      small_buf = NDArray(TShape{pskv.size}, comm_buf.ctx(), false, dtype);
+      res_buf = NDArray(TShape{static_cast<int64_t>(original_size)}, 
comm_buf.ctx(), false, dtype);
       res_buf = 0;
     }
     gradient_compression_->Quantize(comm_buf, &small_buf, &res_buf, priority);
     auto push_to_servers =
-      [this, key, pskv, small_buf](RunContext rctx, Engine::CallbackOnComplete 
cb) {
-        size_t size = small_buf.shape().Size();
-        real_t* data = small_buf.data().dptr<real_t>();
+      [this, key, dtype, pskv, small_buf](RunContext rctx, 
Engine::CallbackOnComplete cb) {
+        size_t size = small_buf.shape().Size() * 
mshadow::mshadow_sizeof(dtype);
+        char* data = static_cast<char *> (small_buf.data().dptr_);
         // do push. false means no delete
-        ps::SArray<real_t> vals(data, size, false);
-        CHECK_NOTNULL(ps_worker_)->ZPush(
-          pskv.keys, vals, pskv.lens,
-          static_cast<int>(DataHandleType::kCompressedPushPull), [cb]() { 
cb(); });
+        ps::SArray<char> vals(data, size, false);
+        int cmd = GetCommandType(RequestType::kCompressedPushPull, dtype);
+        CHECK_NOTNULL(ps_worker_)->ZPush(pskv.keys, vals, pskv.lens, cmd, 
[cb]() { cb(); });
       };
     // acquire locks on both comm_buf and small_buf so that
     // pull (which uses comm_buf) for the same key waits till push finishes
@@ -398,14 +401,16 @@ class KVStoreDist : public KVStoreLocal {
   void PushDefault(int key, const NDArray &send_buf, const PSKV& pskv, int 
priority) {
     auto push_to_servers =
         [this, key, pskv, send_buf](RunContext rctx, 
Engine::CallbackOnComplete cb) {
+          const int dtype = send_buf.dtype();
           // convert to ps keys
-          size_t size = send_buf.shape().Size();
-          real_t* data = send_buf.data().dptr<real_t>();
+          const size_t size = send_buf.shape().Size() * 
mshadow::mshadow_sizeof(dtype);
+          char* data = static_cast<char *>(send_buf.data().dptr_);
           // do push. false means no delete
-          ps::SArray<real_t> vals(data, size, false);
+          ps::SArray<char> vals(data, size, false);
+          int cmd = GetCommandType(RequestType::kDefaultPushPull, dtype);
           CHECK_NOTNULL(ps_worker_)->ZPush(
               pskv.keys, vals, pskv.lens,
-              static_cast<int>(DataHandleType::kDefaultPushPull), [cb]() { 
cb(); });
+              cmd, [cb]() { cb(); });
         };
     Engine::Get()->PushAsync(
         push_to_servers,
@@ -422,23 +427,22 @@ class KVStoreDist : public KVStoreLocal {
     using namespace rowsparse;
     auto push_to_servers = [this, key, send_buf]
                            (RunContext rctx, Engine::CallbackOnComplete cb) {
-      real_t* data = send_buf.data().dptr<real_t>();
+      char* data = static_cast<char *>(send_buf.data().dptr_);
       const int64_t num_rows = send_buf.aux_shape(kIdx)[0];
       const auto offsets = send_buf.aux_data(kIdx).dptr<int64_t>();
       const auto unit_len = send_buf.shape().ProdShape(1, 
send_buf.shape().ndim());
+      const int num_bytes = mshadow::mshadow_sizeof(send_buf.dtype());
       const int64_t size = num_rows * unit_len;
-
        // convert to ps keys in row sparse format
       PSKV& pskv = EncodeRowSparseKey(key, size, num_rows, offsets,
-                                      unit_len, send_buf.shape()[0]);
+                                      unit_len, send_buf.shape()[0], 
num_bytes);
       if (this->log_verbose_) {
         LOG(INFO) << "worker " << get_rank() << " push lens: " << pskv.lens << 
" keys: "
                   << pskv.keys << " size: " << size;
       }
-      ps::SArray<real_t> vals(data, size, false);
-      CHECK_NOTNULL(ps_worker_)->ZPush(pskv.keys, vals, pskv.lens,
-                                       
static_cast<int>(DataHandleType::kRowSparsePushPull),
-                                       [cb]() { cb(); });
+      ps::SArray<char> vals(data, size * num_bytes, false);
+      const int cmd = GetCommandType(RequestType::kRowSparsePushPull, 
send_buf.dtype());
+      CHECK_NOTNULL(ps_worker_)->ZPush(pskv.keys, vals, pskv.lens, cmd, [cb]() 
{ cb(); });
     };
     Engine::Get()->PushAsync(
         push_to_servers,
@@ -460,27 +464,31 @@ class KVStoreDist : public KVStoreLocal {
       // allocate memory for the buffer
       CHECK_EQ(indices.dtype(), mshadow::kInt64);
       const TBlob idx_data = indices.data();
-      size_t num_rows = idx_data.shape_.Size();
+      const size_t num_rows = idx_data.shape_.Size();
       recv_buf.CheckAndAlloc({mshadow::Shape1(num_rows)});
-      real_t* data = recv_buf.data().dptr<real_t>();
+      const int dtype = recv_buf.dtype();
+      char* data = static_cast<char *>(recv_buf.data().dptr_);
       const auto offsets = idx_data.dptr<int64_t>();
       const auto unit_len = recv_buf.shape().ProdShape(1, 
recv_buf.shape().ndim());
       const int64_t size = num_rows * unit_len;
+      const int num_bytes = mshadow::mshadow_sizeof(dtype);
       // convert to ps keys in row sparse format
       PSKV& pskv = EncodeRowSparseKey(key, size, num_rows, offsets,
-                                      unit_len, recv_buf.shape()[0]);
+                                      unit_len, recv_buf.shape()[0],
+                                      num_bytes);
       if (this->log_verbose_) {
         LOG(INFO) << "worker " << get_rank() << " pull lens: " << pskv.lens << 
" keys: "
                   << pskv.keys << " size: " << size;
       }
-      auto vals = new ps::SArray<real_t>(data, size, false);
+      auto vals = new ps::SArray<char>(data, size * num_bytes, false);
+      const int cmd = GetCommandType(RequestType::kRowSparsePushPull, 
recv_buf.dtype());
       // copy indices to recv_buf. this needs to be done before ZPull
       // because after pull is done, the callback function returns and locks 
are released.
       // at this point, later functions may access the indices variable while 
copy happens
       mshadow::Copy(recv_buf.aux_data(kIdx).FlatTo1D<cpu, int64_t>(),
                     idx_data.FlatTo1D<cpu, int64_t>());
       CHECK_NOTNULL(ps_worker_)->ZPull(pskv.keys, vals, &pskv.lens,
-                                       
static_cast<int>(DataHandleType::kRowSparsePushPull),
+                                       cmd,
                                        [vals, cb]() { delete vals; cb(); });
     };
     CHECK_NOTNULL(Engine::Get())->PushAsync(
@@ -504,67 +512,82 @@ class KVStoreDist : public KVStoreLocal {
   }
 
   /**
-   * \brief convert to keys in ps
+   * \brief convert to pskv for parameter server
+   * \param key
+   * \param num_arr_elems number of elements in the value for key
+   * \param num_bytes size of each element in number of bytes
+   * \return PSKV used for both push and pull
    */
-  inline PSKV& EncodeDefaultKey(int key, size_t size, bool is_push) {
+  inline PSKV& EncodeDefaultKey(const int key, const size_t num_arr_elems,
+                                const int num_bytes) {
     mu_.lock();
     PSKV& pskv = ps_kv_[key];
     mu_.unlock();
+    size_t pskv_size = num_arr_elems * num_bytes;
     if (!pskv.keys.empty()) {
-      CHECK_EQ(static_cast<size_t>(pskv.size), size) << "The value size cannot 
be changed";
+      CHECK_EQ(static_cast<size_t>(pskv.size), pskv_size)
+        << "The value size cannot be changed " << pskv_size << ". Key is " << 
key;
     } else {
       auto krs = ps::Postoffice::Get()->GetServerKeyRanges();
-      int num_servers = krs.size();
+      const int num_servers = krs.size();
       CHECK_GT(num_servers, 0);
 
       // a simple heuristic for load balance
-      if (size < bigarray_bound_) {
+      if (num_arr_elems < bigarray_bound_) {
         // send it to a single random picked server
         int server = (key * 9973) % num_servers;
         ps::Key ps_key = krs[server].begin() + key;
         CHECK_LT(ps_key, krs[server].end());
         pskv.keys.push_back(ps_key);
-        pskv.lens.push_back(size);
-        pskv.size = size;
+        const int total_bytes = num_arr_elems * num_bytes;
+        pskv.lens.push_back(total_bytes);
+        pskv.size = total_bytes;
       } else {
        // partition it to all servers
         pskv.size = 0;
         for (int i = 0; i < num_servers; ++i) {
           size_t part_size =
-            
static_cast<size_t>(round(static_cast<double>(size)/num_servers*(i+1))) -
-            
static_cast<size_t>(round(static_cast<double>(size)/num_servers*i));
+            
static_cast<size_t>(round(static_cast<double>(num_arr_elems)/num_servers*(i+1)))
 -
+            
static_cast<size_t>(round(static_cast<double>(num_arr_elems)/num_servers*i));
           ps::Key ps_key = krs[i].begin() + key;
           CHECK_LT(ps_key, krs[i].end());
           pskv.keys.push_back(ps_key);
-          pskv.lens.push_back(part_size);
-          pskv.size += part_size;
+          const int total_bytes = part_size * num_bytes;
+          pskv.lens.push_back(total_bytes);
+          pskv.size += total_bytes;
         }
-        CHECK_EQ(static_cast<size_t>(pskv.size), size);
       }
+      CHECK_EQ(static_cast<size_t>(pskv.size), pskv_size);
     }
     return pskv;
   }
 
   /**
-   * \brief Convert to keys in ps for compressed values
-   * Divides original array into equal parts for each server
-   * Populates both push and pull pskv on first call
+   * \brief Convert to PSKV for pushes and pulls when gradient compression is 
used.
+   * Divides original array into equal parts for each server.
+   * Populates both push and pull pskv on first call.
+   * \param key
+   * \param num_arr_elems number of elements in the value for key
+   * \param is_push whether this is push or pull
+   * \param num_bytes size of each element in number of bytes
+   * \return PSKV used for both push and pull
    */
-  inline PSKV& EncodeCompressedKey(int key, size_t original_size, bool 
is_push) {
+  inline PSKV& EncodeCompressedKey(const int key, const size_t 
original_num_elem,
+                                   const bool is_push, const int num_bytes) {
     auto krs = ps::Postoffice::Get()->GetServerKeyRanges();
-    int num_servers = krs.size();
+    const int num_servers = krs.size();
     CHECK_GT(num_servers, 0);
 
     // represents size of data to be sent
-    size_t compr_size = 
gradient_compression_->GetCompressedSize(original_size);
-
+    size_t compr_num_elem = 
gradient_compression_->GetCompressedSize(original_num_elem);
     mu_.lock();
     PSKV& pskv = (is_push) ? compr_ps_kv_[key].push : compr_ps_kv_[key].pull;
     mu_.unlock();
 
     if (!pskv.keys.empty()) {
-      size_t size = (is_push) ? compr_size : original_size;
-      CHECK_EQ(static_cast<size_t >(pskv.size), size)<< "The value size can't 
be changed";
+      const size_t num_elem = (is_push) ? compr_num_elem : original_num_elem;
+      CHECK_EQ(static_cast<size_t >(pskv.size), num_elem * num_bytes)
+        << "The value size can't be changed. For key " << key;
     } else {
       // populate both pull and push pskvs
       // push pskv has sizes corresponding to compressed data
@@ -574,18 +597,20 @@ class KVStoreDist : public KVStoreLocal {
       PSKV& push_pskv = compr_ps_kv_[key].push;
       mu_.unlock();
 
-      if (original_size < bigarray_bound_) {
+      if (original_num_elem < bigarray_bound_) {
         // a simple heuristic for load balancing
         // send it to a single random picked server
-        int server = (key * 9973) % num_servers;
+        const int server = (key * 9973) % num_servers;
         ps::Key ps_key = krs[server].begin() + key;
         CHECK_LT(ps_key, krs[server].end());
         // meta info
-        push_pskv.keys.push_back(krs[server].begin() + original_size);
+        push_pskv.keys.push_back(krs[server].begin() + original_num_elem);
         push_pskv.lens.push_back(0);
         // data
         push_pskv.keys.push_back(ps_key);
         pull_pskv.keys.push_back(ps_key);
+        const int compr_size = compr_num_elem * num_bytes;
+        const int original_size = original_num_elem * num_bytes;
         push_pskv.lens.push_back(compr_size);
         pull_pskv.lens.push_back(original_size);
         push_pskv.size = compr_size;
@@ -598,12 +623,12 @@ class KVStoreDist : public KVStoreLocal {
         for (int i = 0; i < num_servers; ++i) {
           size_t part_compr, part_orig;
           if (i == num_servers-1) {
-            part_compr = compr_size - push_pskv.size;
-            part_orig = original_size - pull_pskv.size;
+            part_compr = compr_num_elem - push_pskv.size;
+            part_orig = original_num_elem - pull_pskv.size;
           } else {
             part_compr =
-              static_cast<size_t> 
(round(static_cast<double>(compr_size)/num_servers*(i+1))) -
-              static_cast<size_t> 
(round(static_cast<double>(compr_size)/num_servers*(i)));
+              static_cast<size_t> 
(round(static_cast<double>(compr_num_elem)/num_servers*(i+1))) -
+              static_cast<size_t> 
(round(static_cast<double>(compr_num_elem)/num_servers*(i)));
             part_orig = part_compr * 
gradient_compression_->GetCompressionFactor();
           }
 
@@ -618,25 +643,27 @@ class KVStoreDist : public KVStoreLocal {
           CHECK_LT(ps_key, krs[i].end());
           push_pskv.keys.push_back(ps_key);
           pull_pskv.keys.push_back(ps_key);
-          // push_pskv stores lengths of compressed blocks
-          push_pskv.lens.push_back(part_compr);
-          // pull_pskv stores lengths of original data
-          pull_pskv.lens.push_back(part_orig);
+          push_pskv.lens.push_back(part_compr * num_bytes);
+          pull_pskv.lens.push_back(part_orig * num_bytes);
+          // num elements need to be inserted below so that for last server,
+          // there is no round off error
           push_pskv.size += part_compr;
           pull_pskv.size += part_orig;
         }
-        CHECK_EQ(static_cast<size_t>(push_pskv.size), compr_size);
-        CHECK_EQ(static_cast<size_t>(pull_pskv.size), original_size);
-        CHECK_EQ(push_pskv.lens.size(), num_servers*2);
+        CHECK_EQ(static_cast<size_t>(push_pskv.size), compr_num_elem);
+        CHECK_EQ(static_cast<size_t>(pull_pskv.size), original_num_elem);
+        push_pskv.size *= num_bytes;
+        pull_pskv.size *= num_bytes;
+        CHECK_EQ(push_pskv.lens.size(), num_servers * 2);
         }
       }
     return pskv;
   }
 
   // Note: this encoding method for row sparse keys doesn't allow cross-layer 
batching
-  inline PSKV& EncodeRowSparseKey(const int key, const int64_t size, const 
int64_t num_rows,
+  inline PSKV& EncodeRowSparseKey(const int key, const int64_t num_elem, const 
int64_t num_rows,
                                   const int64_t *offsets, const size_t 
unit_len,
-                                  const int64_t total_num_rows) {
+                                  const int64_t total_num_rows, const int 
num_bytes) {
     using namespace common;
     mu_.lock();
     PSKV& pskv = ps_kv_[key];
@@ -645,7 +672,7 @@ class KVStoreDist : public KVStoreLocal {
     pskv.lens.clear();
     // TODO(haibin) cache this information
     auto krs = ps::Postoffice::Get()->GetServerKeyRanges();
-    int num_servers = krs.size();
+    const int num_servers = krs.size();
     CHECK_GT(num_servers, 0);
 
     if (total_num_rows * unit_len >= bigarray_bound_) {
@@ -656,7 +683,7 @@ class KVStoreDist : public KVStoreLocal {
         ps::Key master_key = krs[i].begin() + key;
         pskv.keys.push_back(master_key);
         pskv.lens.push_back(0);
-        if (offsets && size > 0) {
+        if (offsets && num_elem > 0) {
           // calculate partition ranges
           int64_t part_num_rows =
             llround(static_cast<double>(total_num_rows) / num_servers * (i + 
1)) -
@@ -669,16 +696,17 @@ class KVStoreDist : public KVStoreLocal {
             ps::Key ps_key = krs[i].begin() + key + (*offset - start_row);
             CHECK_LT(ps_key, krs[i].end());
             pskv.keys.push_back(ps_key);
-            pskv.lens.push_back(unit_len);
-            pskv.size += unit_len;
+            const int part_size = unit_len * num_bytes;
+            pskv.lens.push_back(part_size);
+            pskv.size += (part_size);
           }
           start_row = end_row;
         }
       }
-      CHECK_EQ(static_cast<size_t>(pskv.size), size);
+      CHECK_EQ(static_cast<size_t>(pskv.size), num_elem * num_bytes);
     } else {
       // send it to a single random picked server
-      int server = (key * 9973) % num_servers;
+      const int server = (key * 9973) % num_servers;
       ps::Key master_key = krs[server].begin() + key;
       pskv.keys.push_back(master_key);
       pskv.lens.push_back(0);
@@ -686,9 +714,9 @@ class KVStoreDist : public KVStoreLocal {
         ps::Key ps_key = krs[server].begin() + key + offsets[i];
         CHECK_LT(ps_key, krs[server].end());
         pskv.keys.push_back(ps_key);
-        pskv.lens.push_back(unit_len);
+        pskv.lens.push_back(unit_len * num_bytes);
       }
-      pskv.size = size;
+      pskv.size = num_elem * num_bytes;
     }
     return pskv;
   }
@@ -696,7 +724,7 @@ class KVStoreDist : public KVStoreLocal {
   /**
    * \brief for worker to push and pull data
    */
-  ps::KVWorker<real_t>* ps_worker_;
+  ps::KVWorker<char>* ps_worker_;
   /**
    * \brief the server handle
    */
diff --git a/src/kvstore/kvstore_dist_server.h 
b/src/kvstore/kvstore_dist_server.h
index c2ddcd8..421de27 100644
--- a/src/kvstore/kvstore_dist_server.h
+++ b/src/kvstore/kvstore_dist_server.h
@@ -40,14 +40,52 @@
 namespace mxnet {
 namespace kvstore {
 
+// Keep this enum in the same order as the corresponding frontend definition.
 enum class CommandType {
-  kController, kStopServer, kSyncMode, kSetGradientCompression
+  kController, kSetMultiPrecision, kStopServer, kSyncMode, 
kSetGradientCompression,
 };
 
-enum class DataHandleType {
-  kDefaultPushPull, kCompressedPushPull, kRowSparsePushPull
+enum class RequestType {
+  kDefaultPushPull, kRowSparsePushPull, kCompressedPushPull
 };
 
+struct DataHandleType {
+  RequestType requestType;
+  int dtype;
+};
+
+/*!
+ * Uses Cantor pairing function to generate a unique number given two numbers.
+ * This number can also be inverted to find the unique pair whose Cantor value 
is this number.
+ * Ref: https://en.wikipedia.org/wiki/Pairing_function#Cantor_pairing_function
+ * \param requestType RequestType
+ * \param d dtype of the data, as an integer
+ * \return Cantor value of arguments
+ */
+static int GetCommandType(RequestType requestType, int d) {
+  int m = static_cast<int>(requestType);
+  return (((m + d) * (m + d + 1)) / 2) + d;
+}
+
+/*!
+ * Inverts the Cantor pairing to recover the two integers that were paired,
+ * then returns a DataHandleType object built from those numbers.
+ * \param cmd DataHandleCommand generated by GetCommandType function
+ * \return DataHandleType
+ */
+static DataHandleType DepairDataHandleType(int cmd) {
+  int w = std::floor((std::sqrt(8 * cmd + 1) - 1)/2);
+  int t = ((w * w) + w) / 2;
+  int y = cmd - t;
+  int x = w - y;
+  CHECK_GE(x, 0);
+  CHECK_GE(y, 0);
+  DataHandleType type;
+  type.requestType = static_cast<RequestType>(x);
+  type.dtype = y;
+  return type;
+}
+
 /**
  * \brief executor runs a function using the thread called \ref Start
  */
@@ -114,7 +152,7 @@ class KVStoreDistServer {
  public:
   KVStoreDistServer() {
     using namespace std::placeholders;
-    ps_server_ = new ps::KVServer<float>(0);
+    ps_server_ = new ps::KVServer<char>(0);
     static_cast<ps::SimpleApp*>(ps_server_)->set_request_handle(
         std::bind(&KVStoreDistServer::CommandHandle, this, _1, _2));
     ps_server_->set_request_handle(
@@ -146,9 +184,11 @@ class KVStoreDistServer {
   }
 
  private:
-  struct MergeBuf {
+  struct UpdateBuf {
     std::vector<ps::KVMeta> request;
-    NDArray array;
+    NDArray merged;
+    // temp_array is used to cast received values to float32 for computation, if required
+    NDArray temp_array;
   };
 
   void CommandHandle(const ps::SimpleData& recved, ps::SimpleApp* app) {
@@ -159,54 +199,114 @@ class KVStoreDistServer {
       sync_mode_ = true;
     } else if (recved_type == CommandType::kSetGradientCompression) {
       gradient_compression_->DecodeParams(recved.body);
-    } else {
-      // this uses value 0 for message id from frontend
+    } else if (recved_type == CommandType::kSetMultiPrecision) {
+      // uses value 1 for message id from frontend
+      if (!multi_precision_) {
+        multi_precision_ = true;
+        CreateMultiPrecisionCopies();
+      }
+    } else if (recved_type == CommandType::kController) {
+      // value of 0
       // let the main thread to execute ctrl, which is necessary for python
       exec_.Exec([this, recved]() {
           CHECK(controller_);
           controller_(recved.head, recved.body);
         });
+    } else {
+      LOG(FATAL) << "Unknown command type received " << recved.head;
     }
     app->Response(recved);
   }
 
+  /*
+   * For keys that are already initialized, create stored_realt if necessary.
+   * This is only needed when, through incorrect usage of kvstore,
+   * some keys were initialized before the optimizer was set.
+   */
+  void CreateMultiPrecisionCopies() {
+    for (auto const& stored_entry : store_) {
+      const int key = stored_entry.first;
+      const NDArray& stored = stored_entry.second;
+      if (stored.dtype() != mshadow::kFloat32) {
+        auto& stored_realt = store_realt_[key];
+        if (stored.storage_type() == kRowSparseStorage) {
+          stored_realt = NDArray(kRowSparseStorage, stored.shape(), 
stored.ctx(),
+                                 true, mshadow::kFloat32);
+        } else {
+          stored_realt = NDArray(stored.shape(), stored.ctx(), false, 
mshadow::kFloat32);
+        }
+
+        auto& update = update_buf_[key];
+        if (!update.merged.is_none()) {
+          if (update.merged.storage_type() == kRowSparseStorage) {
+            update.merged = NDArray(kRowSparseStorage, update.merged.shape(), 
update.merged.ctx(),
+                                    true, mshadow::kFloat32);
+          } else {
+            update.merged = NDArray(update.merged.shape(), 
update.merged.ctx(), false,
+                                    mshadow::kFloat32);
+          }
+        }
+        CHECK(update.request.size() == 0)
+          << ps::MyRank() << "Multiprecision mode can not be set while pushes 
are underway."
+          << "Please set optimizer before pushing keys." << key << " " << 
update.request.size();
+
+        CopyFromTo(stored, stored_realt);
+      }
+    }
+    for (auto const& stored_realt_entry : store_realt_) {
+      stored_realt_entry.second.WaitToRead();
+    }
+  }
+
   void DataHandleEx(const ps::KVMeta& req_meta,
-                    const ps::KVPairs<real_t>& req_data,
-                    ps::KVServer<real_t>* server) {
-    DataHandleType recved_type = static_cast<DataHandleType>(req_meta.cmd);
-    if (recved_type == DataHandleType::kRowSparsePushPull) {
-      DataHandleRowSparse(req_meta, req_data, server);
-    } else if (recved_type == DataHandleType::kCompressedPushPull) {
-      DataHandleCompressed(req_meta, req_data, server);
-    } else {
-      DataHandleDefault(req_meta, req_data, server);
+                    const ps::KVPairs<char>& req_data,
+                    ps::KVServer<char>* server) {
+    DataHandleType type = DepairDataHandleType(req_meta.cmd);
+    switch (type.requestType) {
+      case RequestType::kRowSparsePushPull:
+        DataHandleRowSparse(type, req_meta, req_data, server);
+        break;
+      case RequestType::kCompressedPushPull:
+        DataHandleCompressed(type, req_meta, req_data, server);
+        break;
+      case RequestType::kDefaultPushPull:
+        DataHandleDefault(type, req_meta, req_data, server);
+        break;
     }
-    return;
   }
 
-  inline void ApplyUpdates(const int key, MergeBuf *merged, NDArray *stored,
-                           ps::KVServer<real_t>* server) {
-    if (merged->request.size() == (size_t) ps::NumWorkers()) {
+  inline bool has_multi_precision_copy(const DataHandleType type) {
+    return multi_precision_ && type.dtype != mshadow::kFloat32;
+  }
+
+  inline void ApplyUpdates(const DataHandleType type, const int key,
+                           UpdateBuf *update_buf, ps::KVServer<char>* server) {
+    if (!sync_mode_ || update_buf->request.size() == (size_t) 
ps::NumWorkers()) {
       // let the main thread to execute updater_, which is necessary for python
+      auto& stored = has_multi_precision_copy(type) ? store_realt_[key] : 
store_[key];
+      auto& update =  sync_mode_ ? update_buf->merged : update_buf->temp_array;
       if (updater_) {
-        exec_.Exec([this, key, merged, stored](){
-            CHECK(updater_);
-            updater_(key, merged->array, stored);
-          });
+        exec_.Exec([this, key, &update, &stored](){
+          CHECK(updater_);
+          updater_(key, update, &stored);
+        });
       } else {
+        CHECK(sync_mode_) << "Updater needs to be set for async mode";
         // if no updater, just copy
-        CopyFromTo(merged->array, stored);
+        CopyFromTo(update_buf->merged, &stored);
       }
+
       if (log_verbose_)  {
-        LOG(INFO) << "sync response to " << merged->request.size() << " 
workers";
+        LOG(INFO) << "sent response to " << update_buf->request.size() << " 
workers";
       }
-      for (const auto& req : merged->request) {
+      for (const auto& req : update_buf->request) {
         server->Response(req);
       }
-      merged->request.clear();
-      stored->WaitToRead();
+      update_buf->request.clear();
+      if (has_multi_precision_copy(type)) CopyFromTo(stored, store_[key]);
+      stored.WaitToRead();
     } else {
-      merged->array.WaitToRead();
+      update_buf->merged.WaitToRead();
     }
   }
 
@@ -220,175 +320,229 @@ class KVStoreDistServer {
     }
   }
 
-  void DataHandleRowSparse(const ps::KVMeta& req_meta,
-                       const ps::KVPairs<real_t>& req_data,
-                       ps::KVServer<real_t>* server) {
+  void AccumulateRowSparseGrads(const DataHandleType type,
+                                const NDArray& recved,
+                                UpdateBuf* updateBuf) {
+    NDArray out(kRowSparseStorage, updateBuf->merged.shape(), Context(), true,
+                has_multi_precision_copy(type) ? mshadow::kFloat32 : 
type.dtype);
+    if (has_multi_precision_copy(type)) CopyFromTo(recved, 
updateBuf->temp_array);
+    const NDArray& to_merge = has_multi_precision_copy(type) ? 
updateBuf->temp_array : recved;
+    // accumulate row_sparse gradients
+    // TODO(haibin) override + operator for row_sparse NDArray
+    // instead of calling BinaryComputeRspRsp directly
+    using namespace mshadow;
+    Engine::Get()->PushAsync(
+    [to_merge, updateBuf, out](RunContext ctx, Engine::CallbackOnComplete 
on_complete) {
+      op::ElemwiseBinaryOp::ComputeEx<cpu, op::mshadow_op::plus>(
+      {}, {}, {to_merge, updateBuf->merged}, {kWriteTo}, {out});
+      on_complete();
+    }, to_merge.ctx(), {to_merge.var(), updateBuf->merged.var()}, {out.var()},
+    FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME);
+    CopyFromTo(out, &(updateBuf->merged), 0);
+    updateBuf->merged.WaitToRead();
+  }
+
+  void RowSparsePullResponse(const DataHandleType type,
+                             const int master_key,
+                             const size_t num_rows,
+                             const ps::KVMeta& req_meta,
+                             const ps::KVPairs<char>& req_data,
+                             ps::KVServer<char>* server) {
+    if (log_verbose_) LOG(INFO) << "pull: " << master_key;
+    ps::KVPairs<char> response;
+    if (num_rows == 0) {
+      std::vector<int> lens(req_data.keys.size(), 0);
+      response.keys = req_data.keys;
+      response.lens.CopyFrom(lens.begin(), lens.end());
+      server->Response(req_meta, response);
+      return;
+    }
+    const NDArray& stored = store_[master_key];
+    if (has_multi_precision_copy(type)) stored.WaitToRead();
+    CHECK(!stored.is_none()) << "init " << master_key << " first";
+    auto shape = stored.shape();
+    auto unit_len = shape.ProdShape(1, shape.ndim());
+    const int num_bytes = mshadow::mshadow_sizeof(type.dtype);
+    const int unit_size = unit_len * num_bytes;
+    const char* data = static_cast<char *> (stored.data().dptr_);
+    auto len = num_rows * unit_size;
+    // concat values
+    response.vals.resize(len);
+    #pragma omp parallel for
+    for (size_t i = 1; i <= num_rows; i++) {
+      int key = DecodeKey(req_data.keys[i]);
+      int64_t row_id = key - master_key;
+      const auto src = data + row_id * unit_size;
+      auto begin = (i - 1) * unit_size;
+      auto end = i * unit_size;
+      response.vals.segment(begin, end).CopyFrom(src, unit_size);
+    }
+    // setup response
+    response.keys = req_data.keys;
+    std::vector<int> lens(req_data.keys.size(), unit_len);
+    lens[0] = 0;
+    response.lens.CopyFrom(lens.begin(), lens.end());
+    server->Response(req_meta, response);
+  }
+
+  void InitRowSparseStored(const DataHandleType type,
+                           const int master_key,
+                           const size_t num_rows,
+                           const ps::KVMeta& req_meta,
+                           const ps::KVPairs<char>& req_data,
+                           ps::KVServer<char>* server) {
+    auto& stored = has_multi_precision_copy(type) ? store_realt_[master_key] : 
store_[master_key];
+    int dtype = type.dtype;
+    int num_bytes = mshadow::mshadow_sizeof(dtype);
+    auto unit_len = req_data.lens[1] / num_bytes;
+    CHECK_GT(unit_len, 0);
+    size_t ds[] = {num_rows, (size_t) unit_len};
+    TShape dshape(ds, ds + 2);
+    CHECK_EQ(req_data.vals.size(), num_rows * unit_len * num_bytes);
+    TBlob recv_blob;
+    MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
+      recv_blob = TBlob(reinterpret_cast<DType*>(req_data.vals.data()), 
dshape, cpu::kDevMask);
+    })
+    NDArray recved = NDArray(recv_blob, 0);
+    stored = NDArray(kRowSparseStorage, dshape, Context(), true,
+                     has_multi_precision_copy(type) ? mshadow::kFloat32 : 
type.dtype);
+    if (has_multi_precision_copy(type)) {
+      store_[master_key] = NDArray(kRowSparseStorage, dshape, Context(), true, 
type.dtype);
+    }
+    Engine::Get()->PushAsync(
+    [this, recved, stored, type](RunContext ctx, Engine::CallbackOnComplete 
on_complete) {
+      NDArray rsp = stored;
+      stored.CheckAndAlloc({mshadow::Shape1(recved.shape()[0])});
+      mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
+      using namespace mxnet::op;
+      nnvm::dim_t nnr = rsp.shape()[0];
+      MSHADOW_IDX_TYPE_SWITCH(rsp.aux_type(rowsparse::kIdx), IType, {
+        IType* idx = rsp.aux_data(rowsparse::kIdx).dptr<IType>();
+        mxnet_op::Kernel<PopulateFullIdxRspKernel, cpu>::Launch(s, nnr, idx);
+      });
+      TBlob rsp_data = rsp.data();
+      // copies or casts as appropriate
+      ndarray::Copy<cpu, cpu>(recved.data(), &rsp_data, Context(), Context(), 
RunContext());
+      on_complete();
+    }, recved.ctx(), {recved.var()}, {stored.var()},
+    FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME);
+    if (has_multi_precision_copy(type)) {
+      CopyFromTo(stored, store_[master_key]);
+      store_[master_key].WaitToRead();
+    }
+    stored.WaitToRead();
+    server->Response(req_meta);
+  }
+
+  void DataHandleRowSparse(const DataHandleType type, const ps::KVMeta& 
req_meta,
+                           const ps::KVPairs<char>& req_data,
+                           ps::KVServer<char>* server) {
     int master_key = DecodeKey(req_data.keys[0]);
     auto num_rows = req_data.keys.size() - 1;
     auto& stored = store_[master_key];
     if (req_meta.push) {
       CHECK_GT(req_data.lens.size(), 0) << "req_data.lens cannot be empty";
       CHECK_EQ(req_data.lens[0], 0);
-      real_t* data = req_data.vals.data();
       if (stored.is_none()) {
         if (log_verbose_) LOG(INFO) << "initial push: " << master_key;
         // initialization
         CHECK_GT(num_rows, 0) << "init with empty data is not supported";
-        auto unit_len = req_data.lens[1];
-        CHECK_GT(unit_len, 0);
-        size_t ds[] = {num_rows, (size_t) unit_len};
-        TShape dshape(ds, ds + 2);
-        CHECK_EQ(req_data.vals.size(), num_rows * unit_len);
-        TBlob recv_blob(data, dshape, cpu::kDevMask);  // NOLINT(*)
-        NDArray recved = NDArray(recv_blob, 0);
-        stored = NDArray(kRowSparseStorage, dshape, Context());
-        Engine::Get()->PushAsync(
-          [recved, stored](RunContext ctx, Engine::CallbackOnComplete 
on_complete) {
-            NDArray rsp = stored;
-            stored.CheckAndAlloc({mshadow::Shape1(recved.shape()[0])});
-            mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
-            using namespace mxnet::op;
-            nnvm::dim_t nnr = rsp.shape()[0];
-            MSHADOW_IDX_TYPE_SWITCH(rsp.aux_type(rowsparse::kIdx), IType, {
-              IType* idx = rsp.aux_data(rowsparse::kIdx).dptr<IType>();
-              mxnet_op::Kernel<PopulateFullIdxRspKernel, cpu>::Launch(s, nnr, 
idx);
-            });
-            mshadow::Copy(rsp.data().FlatTo1D<cpu, float>(),
-                          recved.data().FlatTo1D<cpu, float>(), s);
-            on_complete();
-          }, recved.ctx(), {recved.var()}, {stored.var()},
-          FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME);
-        stored.WaitToRead();
-        server->Response(req_meta);
+        InitRowSparseStored(type, master_key, num_rows, req_meta, req_data, 
server);
         return;
-      }
-      // synced push
-      if (sync_mode_) {
-        if (log_verbose_) LOG(INFO) << "sync push: " << master_key << " " << 
req_data.keys;
-        auto& merged = merge_buf_[master_key];
-        if (merged.array.is_none()) {
-          merged.array = NDArray(kRowSparseStorage, stored.shape(), Context());
+      } else {
+        if (log_verbose_) LOG(INFO) << "push: " << master_key << " " << 
req_data.keys;
+        auto& updates = update_buf_[master_key];
+        if (sync_mode_ && updates.merged.is_none()) {
+          updates.merged = NDArray(kRowSparseStorage, stored.shape(), 
Context(), true,
+                                   has_multi_precision_copy(type) ? 
mshadow::kFloat32 : type.dtype);
         }
+        if (has_multi_precision_copy(type) && updates.temp_array.is_none()) {
+          updates.temp_array = NDArray(kRowSparseStorage, stored.shape(), 
Context(), false,
+                                       mshadow::kFloat32);
+        }
+
         if (num_rows == 0) {
-          // reset to zeros
-          if (merged.request.size() == 0) {
-            merged.array = NDArray(kRowSparseStorage, stored.shape(), 
Context());
+          if (sync_mode_) {
+            if (updates.request.empty()) {
+              // reset to zeros
+              int merged_dtype = has_multi_precision_copy(type) ? 
mshadow::kFloat32 : type.dtype;
+              updates.merged = NDArray(kRowSparseStorage, stored.shape(), 
Context(),
+                                       true, merged_dtype);
+            }  // else nothing to aggregate
+            updates.request.push_back(req_meta);
+            ApplyUpdates(type, master_key, &updates, server);
           } else {
-            // nothing to aggregate
+            server->Response(req_meta);
           }
-          merged.request.push_back(req_meta);
-          ApplyUpdates(master_key, &merged,  &stored, server);
-          return;
-        }
-        auto unit_len = req_data.lens[1];
-        CHECK_GT(unit_len, 0);
-        // indices
-        std::vector<int64_t> indices(num_rows);
-        DecodeRowIds(req_data.keys, indices.data(), master_key, num_rows);
-        // data
-        TBlob idx_blob(indices.data(), mshadow::Shape1(num_rows), 
cpu::kDevMask);
-        size_t ds[] = {(size_t) num_rows, (size_t) unit_len};
-        TShape dshape(ds, ds + 2);
-        TBlob recv_blob(data, dshape, cpu::kDevMask); // NOLINT(*)
-        // row_sparse NDArray
-        NDArray recved(kRowSparseStorage, stored.shape(), recv_blob, 
{idx_blob}, 0);
-
-        if (merged.request.size() == 0) {
-          CopyFromTo(recved, &merged.array, 0);
         } else {
-          NDArray out(kRowSparseStorage, stored.shape(), Context());
-          // accumulate row_sparse gradients
-          // TODO(haibin) override + operator for row_sparse NDArray
-          // instead of calling BinaryComputeRspRsp directly
-          using namespace mshadow;
-          Engine::Get()->PushAsync(
-            [recved, merged, out](RunContext ctx, Engine::CallbackOnComplete 
on_complete) {
-              op::ElemwiseBinaryOp::ComputeEx<cpu, op::mshadow_op::plus>(
-                {}, {}, {recved, merged.array}, {kWriteTo}, {out});
-              on_complete();
-            }, recved.ctx(), {recved.var(), merged.array.var()}, {out.var()},
-            FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME);
-          CopyFromTo(out, &merged.array, 0);
-        }
-        merged.request.push_back(req_meta);
-        ApplyUpdates(master_key, &merged,  &stored, server);
-      } else {
-        // async push
-        if (log_verbose_) LOG(INFO) << "async push: " << master_key;
-        if (num_rows == 0) {
-          server->Response(req_meta);
-          return;
+          auto unit_len = req_data.lens[1] / 
mshadow::mshadow_sizeof(type.dtype);
+          CHECK_GT(unit_len, 0);
+          // indices
+          std::vector<int64_t> indices(num_rows);
+          DecodeRowIds(req_data.keys, indices.data(), master_key, num_rows);
+
+          // data
+          TBlob idx_blob(indices.data(), mshadow::Shape1(num_rows), 
cpu::kDevMask);
+          size_t ds[] = {(size_t) num_rows, (size_t) unit_len};
+          TShape dshape(ds, ds + 2);
+          TBlob recv_blob;
+          MSHADOW_REAL_TYPE_SWITCH(type.dtype, DType, {
+            recv_blob = TBlob(reinterpret_cast<DType*>(req_data.vals.data()),
+                              dshape, cpu::kDevMask);
+          })
+          // row_sparse NDArray
+          NDArray recved(kRowSparseStorage, stored.shape(), recv_blob, 
{idx_blob}, 0);
+
+          if (updates.request.empty()) {
+            if (sync_mode_) {
+              CopyFromTo(recved, updates.merged);
+            } else {
+              if (has_multi_precision_copy(type)) {
+                CopyFromTo(recved, updates.temp_array);
+              } else {
+                updates.temp_array = recved;
+              }
+            }
+          } else {
+            CHECK(sync_mode_);
+            AccumulateRowSparseGrads(type, recved, &updates);
+          }
+          updates.request.push_back(req_meta);
+          ApplyUpdates(type, master_key, &updates, server);
         }
-        auto unit_len = req_data.lens[1];
-        CHECK_GT(unit_len, 0);
-        // indices
-        std::vector<int64_t> indices(num_rows);
-        DecodeRowIds(req_data.keys, indices.data(), master_key, num_rows);
-        TBlob idx_blob(indices.data(), mshadow::Shape1(num_rows), 
cpu::kDevMask);
-        size_t ds[] = {(size_t) num_rows, (size_t) unit_len};
-        TShape dshape(ds, ds + 2);
-        TBlob recv_blob(data, dshape, cpu::kDevMask); // NOLINT(*)
-        NDArray recved(kRowSparseStorage, stored.shape(), recv_blob, 
{idx_blob}, 0);
-        exec_.Exec([this, master_key, &recved, &stored](){
-            CHECK(updater_);
-            updater_(master_key, recved, &stored);
-          });
-        server->Response(req_meta);
-        stored.WaitToRead();
       }
     } else {
       // pull
-      if (log_verbose_) LOG(INFO) << "pull: " << master_key;
-      ps::KVPairs<real_t> response;
-      if (num_rows == 0) {
-        std::vector<int> lens(req_data.keys.size(), 0);
-        response.keys = req_data.keys;
-        response.lens.CopyFrom(lens.begin(), lens.end());
-        server->Response(req_meta, response);
-        return;
-      }
-      CHECK(!stored.is_none()) << "init " << master_key << " first";
-      auto shape = stored.shape();
-      auto unit_len = shape.ProdShape(1, shape.ndim());
-      const float* data = stored.data().dptr<float>();
-      auto len = unit_len * num_rows;
-      // concat values
-      response.vals.resize(len);
-      #pragma omp parallel for
-      for (size_t i = 1; i <= num_rows; i++) {
-        int key = DecodeKey(req_data.keys[i]);
-        int64_t row_id = key - master_key;
-        const auto src = data + row_id * unit_len;
-        auto begin = (i - 1) * unit_len;
-        auto end = i * unit_len;
-        response.vals.segment(begin, end).CopyFrom(src, unit_len);
-      }
-      // setup response
-      response.keys = req_data.keys;
-      std::vector<int> lens(req_data.keys.size(), unit_len);
-      lens[0] = 0;
-      response.lens.CopyFrom(lens.begin(), lens.end());
-      server->Response(req_meta, response);
+      RowSparsePullResponse(type, master_key, num_rows, req_meta, req_data, 
server);
     }
   }
 
-  void DefaultStorageResponse(int key, const NDArray& stored,
+  void DefaultStorageResponse(const DataHandleType type,
+                              const int key,
                               const ps::KVMeta& req_meta,
-                              const ps::KVPairs<real_t> &req_data,
-                              ps::KVServer<real_t>* server) {
-    ps::KVPairs<real_t> response;
+                              const ps::KVPairs<char> &req_data,
+                              ps::KVServer<char>* server) {
+    ps::KVPairs<char> response;
+    const NDArray& stored = store_[key];
     CHECK(!stored.is_none()) << "init " << key << " first";
-    auto len = stored.shape().Size();
+
+    // in this case the server responds once store_realt is ready,
+    // so wait here for the copy into store_ to complete
+    if (has_multi_precision_copy(type)) stored.WaitToRead();
+
+    auto len = stored.shape().Size() * mshadow::mshadow_sizeof(stored.dtype());
     response.keys = req_data.keys;
     response.lens = {len};
     // TODO(mli) try to remove this CopyFrom
-    response.vals.CopyFrom(static_cast<const float*>(stored.data().dptr_), 
len);
+    response.vals.CopyFrom(static_cast<const char*>(stored.data().dptr_), len);
     server->Response(req_meta, response);
   }
 
-  void DataHandleCompressed(const ps::KVMeta& req_meta,
-                            const ps::KVPairs<real_t> &req_data,
-                            ps::KVServer<real_t>* server) {
+  void DataHandleCompressed(const DataHandleType type,
+                            const ps::KVMeta& req_meta,
+                            const ps::KVPairs<char> &req_data,
+                            ps::KVServer<char>* server) {
+    CHECK_EQ(type.dtype, mshadow::kFloat32)
+      << "Gradient compression is currently supported for fp32 only";
     if (req_meta.push) {
       // there used several WaitToRead, this is because \a recved's memory
       // could be deallocated when this function returns. so we need to make 
sure
@@ -403,10 +557,9 @@ class KVStoreDistServer {
       int key = DecodeKey(req_data.keys[1]);
       auto& stored = store_[key];
 
-      size_t ds[] = {(size_t)req_data.lens[1]};
+      size_t ds[] = {(size_t)req_data.lens[1] / 
mshadow::mshadow_sizeof(type.dtype)};
       TShape dshape(ds, ds + 1);
-      TBlob recv_blob((real_t*) req_data.vals.data(), // NOLINT(*)
-                      dshape, cpu::kDevMask);
+      TBlob recv_blob(reinterpret_cast<real_t*>(req_data.vals.data()), dshape, 
cpu::kDevMask);
       NDArray recved = NDArray(recv_blob, 0);
 
       NDArray decomp_buf = decomp_buf_[key];
@@ -423,18 +576,18 @@ class KVStoreDistServer {
         stored.WaitToRead();
       } else if (sync_mode_) {
         // synced push
-        auto& merged = merge_buf_[key];
-        if (merged.array.is_none()) {
-          merged.array = NDArray(dshape, Context());
+        auto& merged = update_buf_[key];
+        if (merged.merged.is_none()) {
+          merged.merged = NDArray(dshape, Context());
         }
         if (merged.request.size() == 0) {
-          gradient_compression_->Dequantize(recved, &merged.array, 0);
+          gradient_compression_->Dequantize(recved, &merged.merged, 0);
         } else {
           gradient_compression_->Dequantize(recved, &decomp_buf, 0);
-          merged.array += decomp_buf;
+          merged.merged += decomp_buf;
         }
         merged.request.push_back(req_meta);
-        ApplyUpdates(key, &merged, &stored, server);
+        ApplyUpdates(type, key, &merged, server);
       } else {
         // async push
         gradient_compression_->Dequantize(recved, &decomp_buf, 0);
@@ -449,63 +602,78 @@ class KVStoreDistServer {
       CHECK_EQ(req_data.keys.size(), (size_t)1);
       CHECK_EQ(req_data.lens.size(), (size_t)0);
       int key = DecodeKey(req_data.keys[0]);
-      DefaultStorageResponse(key, store_[key], req_meta, req_data, server);
+      DefaultStorageResponse(type, key, req_meta, req_data, server);
     }
   }
 
-  void DataHandleDefault(const ps::KVMeta& req_meta,
-                         const ps::KVPairs<real_t> &req_data,
-                         ps::KVServer<real_t>* server) {
-    CHECK_EQ(req_meta.cmd, static_cast<int>(DataHandleType::kDefaultPushPull));
+  void DataHandleDefault(const DataHandleType type, const ps::KVMeta& req_meta,
+                         const ps::KVPairs<char> &req_data,
+                         ps::KVServer<char>* server) {
     // do some check
     CHECK_EQ(req_data.keys.size(), (size_t)1);
     if (req_meta.push) {
       CHECK_EQ(req_data.lens.size(), (size_t)1);
       CHECK_EQ(req_data.vals.size(), (size_t)req_data.lens[0]);
     }
-
     int key = DecodeKey(req_data.keys[0]);
-    auto& stored = store_[key];
-
+    auto& stored = has_multi_precision_copy(type) ? store_realt_[key] : 
store_[key];
     // there used several WaitToRead, this is because \a recved's memory
     // could be deallocated when this function returns. so we need to make sure
     // the operators with \a NDArray are actually finished
     if (req_meta.push) {
-      size_t ds[] = {(size_t)req_data.lens[0]};
+      size_t ds[] = {(size_t) req_data.lens[0] / 
mshadow::mshadow_sizeof(type.dtype)};
       TShape dshape(ds, ds + 1);
-      TBlob recv_blob((real_t*)req_data.vals.data(), // NOLINT(*)
-                      dshape, cpu::kDevMask);
+      TBlob recv_blob;
+      MSHADOW_REAL_TYPE_SWITCH(type.dtype, DType, {
+        recv_blob = TBlob(reinterpret_cast<DType*>(req_data.vals.data()), 
dshape, cpu::kDevMask);
+      })
       NDArray recved = NDArray(recv_blob, 0);
       if (stored.is_none()) {
         // initialization
-        stored = NDArray(dshape, Context());
+        stored = NDArray(dshape, Context(), false,
+                         has_multi_precision_copy(type) ? mshadow::kFloat32 : 
type.dtype);
         CopyFromTo(recved, &stored, 0);
         server->Response(req_meta);
+        if (has_multi_precision_copy(type)) {
+          auto& stored_dtype = store_[key];
+          stored_dtype = NDArray(dshape, Context(), false, type.dtype);
+          CopyFromTo(stored, stored_dtype);
+          stored_dtype.WaitToRead();
+        }
         stored.WaitToRead();
-      } else if (sync_mode_) {
-        // synced push
-        auto& merged = merge_buf_[key];
-        if (merged.array.is_none()) {
-          merged.array = NDArray(dshape, Context());
+      } else {
+        auto &updates = update_buf_[key];
+        if (sync_mode_ && updates.merged.is_none()) {
+          updates.merged = NDArray(dshape, Context(), false,
+                                   has_multi_precision_copy(type) ? 
mshadow::kFloat32 : type.dtype);
         }
-        if (merged.request.size() == 0) {
-          CopyFromTo(recved, &merged.array, 0);
+        if (has_multi_precision_copy(type) && updates.temp_array.is_none()) {
+          updates.temp_array = NDArray(dshape, Context(), false, 
mshadow::kFloat32);
+        }
+        if (updates.request.empty()) {
+          if (sync_mode_) {
+            CopyFromTo(recved, updates.merged);
+          } else {
+            if (has_multi_precision_copy(type)) {
+              CopyFromTo(recved, updates.temp_array);
+            } else {
+              updates.temp_array = recved;
+            }
+          }
         } else {
-          merged.array += recved;
+          CHECK(sync_mode_);
+          if (has_multi_precision_copy(type)) {
+            CopyFromTo(recved, updates.temp_array);
+            updates.merged += updates.temp_array;
+          } else {
+            updates.merged += recved;
+          }
         }
-        merged.request.push_back(req_meta);
-        ApplyUpdates(key, &merged, &stored, server);
-      } else {
-        // async push
-        exec_.Exec([this, key, &recved, &stored](){
-            CHECK(updater_);
-            updater_(key, recved, &stored);
-          });
-        server->Response(req_meta);
-        stored.WaitToRead();
+        updates.request.push_back(req_meta);
+        ApplyUpdates(type, key, &updates, server);
       }
     } else {
-      DefaultStorageResponse(key, stored, req_meta, req_data, server);
+      DefaultStorageResponse(type, key, req_meta, req_data, server);
     }
   }
 
@@ -526,13 +694,14 @@ class KVStoreDistServer {
    * \brief store_ contains the value at kvstore for each key
    */
   std::unordered_map<int, NDArray> store_;
+  std::unordered_map<int, NDArray> store_realt_;
 
   /**
    * \brief merge_buf_ is a buffer used if sync_mode is true. It represents
    * values from different workers being merged. The store will be updated
    * to this value when values from all workers are pushed into this buffer.
    */
-  std::unordered_map<int, MergeBuf> merge_buf_;
+  std::unordered_map<int, UpdateBuf> update_buf_;
 
   /**
    * \brief decomp_buf_ is a buffer into which compressed values are
@@ -541,11 +710,18 @@ class KVStoreDistServer {
   std::unordered_map<int, NDArray> decomp_buf_;
 
   Executor exec_;
-  ps::KVServer<float>* ps_server_;
+  ps::KVServer<char>* ps_server_;
 
   // whether to LOG verbose information
   bool log_verbose_;
 
+  /**
+   * \brief whether to use multi precision mode.
+   * in multi precision mode, all weights are stored as float32.
+   * any gradient received will be cast to float32 before accumulation and 
updating of weights.
+   */
+  bool multi_precision_;
+
   /**
    * \brief gradient compression object.
    * starts with none, used after SetGradientCompression sets the type
diff --git a/tests/nightly/dist_sync_kvstore.py 
b/tests/nightly/dist_sync_kvstore.py
index 3a3c916..3bf5cbf 100644
--- a/tests/nightly/dist_sync_kvstore.py
+++ b/tests/nightly/dist_sync_kvstore.py
@@ -20,132 +20,169 @@
 # pylint: skip-file
 import sys
 sys.path.insert(0, "../../python/")
+import argparse
 import mxnet as mx
 import numpy as np
 import numpy.random as rnd
 from mxnet.test_utils import assert_almost_equal
 from test_kvstore import compute_expected_2bit_quantization
 
-def check_diff_to_scalar(A, x, rank=None):
-    """ assert A == x"""
-    assert(np.sum(np.abs((A - x).asnumpy())) == 0), (rank, A.asnumpy(), x)
+def check_diff(A, x, rank=None):
+    """ assert A == x
+        x can be scalar as well as numpy array
+    """
+    assert (np.sum(np.abs((A - x).asnumpy())) == 0), (rank, A.asnumpy(), 
x.asnumpy())
 
 # setup
-keys = ['3', '5', '7']
-rsp_keys = ['9', '11', '13']
-init_test_keys = [str(i) for i in range(200,300)]
-init_test_keys_big = [str(i) for i in range(300,400)]
-init_test_keys_device = [str(i) for i in range(400,500)]
-init_test_keys_device_big = [str(i) for i in range(500,600)]
-
-rate = 2
 shape = (2, 3)
 irregular_shape = (1211,1211)
 big_shape = (1200, 1200)        # bigger than MXNET_KVSTORE_BIGARRAY_BOUND
 
+keys_shape = ['3', '5', '7']
+keys_big_shape = ['99']
+fp16_keys_shape = ['4', '6', '8']
+fp16_keys_big_shape = ['100']
+
+rsp_keys_shape = ['9', '11', '13']
+rsp_keys_big_shape = ['97']
+fp16_rsp_keys_shape = ['10', '12', '14']
+fp16_rsp_keys_big_shape = ['98']
+
+keys_shapes = [(k, shape) for k in keys_shape] + [(k, big_shape) for k in 
keys_big_shape]
+fp16_keys_shapes = [(k, shape) for k in fp16_keys_shape] + [(k, big_shape) for 
k in fp16_keys_big_shape]
+
+init_test_keys = [str(i) for i in range(200, 300)]
+init_test_keys_big = [str(i) for i in range(300, 400)]
+init_test_keys_device = [str(i) for i in range(400, 500)]
+init_test_keys_device_big = [str(i) for i in range(500, 600)]
+
+compr_keys_shapes = [('1000', shape), ('1200', irregular_shape),('1300', 
big_shape)]
+compr_init_keys_shapes = [('1001', shape), ('1201', irregular_shape),('1301', 
big_shape)]
+compr_random_keys_shapes = [('1002', shape),('1202', irregular_shape),('1302', 
big_shape)]
+
+rate = 2
+
 kv = mx.kv.create('dist_sync')
 
+my_rank = kv.rank
+nworker = kv.num_workers
+
 def init_kv():
-    # init kv dns keys
-    kv.init(keys, [mx.nd.ones(shape)] * len(keys))
-    kv.init('99', mx.nd.ones(big_shape))
-    # init kv row_sparse keys
-    kv.init(rsp_keys, [mx.nd.ones(shape).tostype('row_sparse')] * 
len(rsp_keys))
-    kv.init('100', mx.nd.ones(big_shape).tostype('row_sparse'))
-    # worker info
-    my_rank = kv.rank
-    nworker = kv.num_workers
+    # init kv dns keys
+    kv.init(keys_shape, [mx.nd.ones(shape)] * len(keys_shape))
+    kv.init(keys_big_shape, [mx.nd.ones(big_shape)] * len(keys_big_shape))
+    # init kv row_sparse keys
+    kv.init(rsp_keys_shape, [mx.nd.ones(shape).tostype('row_sparse')] * 
len(rsp_keys_shape))
+    kv.init(rsp_keys_big_shape, [mx.nd.ones(big_shape).tostype('row_sparse')] 
* len(rsp_keys_big_shape))
+    # init fp16 dns keys
+    kv.init(fp16_keys_shape, [mx.nd.ones(shape, dtype='float16')] * 
len(keys_shape))
+    kv.init(fp16_keys_big_shape, [mx.nd.ones(big_shape, dtype='float16')] * 
len(keys_big_shape))
+    # init fp16 row_sparse keys
+    kv.init(fp16_rsp_keys_shape, [mx.nd.ones(shape, 
dtype='float16').tostype('row_sparse')] * len(fp16_rsp_keys_shape))
+    kv.init(fp16_rsp_keys_big_shape, [mx.nd.ones(big_shape, 
dtype='float16').tostype('row_sparse')] * len(fp16_rsp_keys_big_shape))
+    return kv
+
+def set_optimizer(use_multiprecision):
     # init updater on servers
-    kv.set_optimizer(mx.optimizer.create('test', rescale_grad=rate))
-    return kv, my_rank, nworker
+    kv.set_optimizer(mx.optimizer.create('test', rescale_grad=rate, 
multi_precision=use_multiprecision))
+    return kv
 
 def init_kv_compressed(kv):
     threshold = 0.5
-    kv.set_gradient_compression({'type': '2bit', 'threshold':threshold})
+    kv.set_gradient_compression({'type': '2bit', 'threshold': threshold})
     # init kv compression keys
-    kv.init('11221', mx.nd.zeros(big_shape))
-    kv.init('112221', mx.nd.zeros(irregular_shape))
-    kv.init('1121', mx.nd.zeros(shape))
+    for k, s in compr_keys_shapes:
+        kv.init(k, mx.nd.zeros(s))
     # to test inactive mode
-    kv.init('1122', mx.nd.ones(shape))
+    for k, s in compr_init_keys_shapes:
+        kv.init(k, mx.nd.ones(s))
     return kv, threshold
 
-def test_sync_push_pull():
-    kv, my_rank, nworker = init_kv()
-    def check_default_keys(kv, my_rank, nworker):
-        nrepeat = 3
+def test_sync_push_pull(nrepeat):
+    def check_default_keys(dtype, nrepeat):
         # checks pull after push in loop, because behavior during
         # consecutive pushes doesn't offer any guarantees
-        for i in range(nrepeat):
-            kv.push('3', mx.nd.ones(shape)*(my_rank+1))
-            kv.push('99', mx.nd.ones(big_shape)*(my_rank+1))
-            num = (nworker + 1) * nworker * rate / 2 * (i + 1) + 1
-            val = mx.nd.zeros(shape)
-            kv.pull('3', out=val)
-            check_diff_to_scalar(val, num)
-            val2 = mx.nd.zeros(big_shape)
-            kv.pull('99', out=val2)
-            check_diff_to_scalar(val2, num)
-
-    def check_row_sparse_keys(kv, my_rank, nworker):
-        nrepeat = 3
+        ks = keys_shapes if dtype == 'float32' else fp16_keys_shapes
+        for k, s in ks:
+            for i in range(nrepeat):
+                kv.push(k, mx.nd.ones(s, dtype=dtype)*(my_rank+1))
+                num = (nworker + 1) * nworker * rate / 2 * (i + 1) + 1
+                val = mx.nd.zeros(s, dtype=dtype)
+                kv.pull(k, out=val)
+                check_diff(val, num)
+
+    def check_row_sparse_keys(dtype, nrepeat):
         # prepare gradient
-        v = mx.nd.zeros(shape)
+        v = mx.nd.zeros(shape, dtype=dtype)
         my_row = my_rank % shape[0]
         v[my_row] = my_rank + 1
         # push
+        if dtype == 'float32':
+            k = rsp_keys_shape[0]
+        else:
+            k = fp16_rsp_keys_shape[0]
+        s = shape
         for i in range(nrepeat):
-            kv.push('9', v.tostype('row_sparse'))
+            kv.push(k, v.tostype('row_sparse'))
             # select a random subset of rows this worker is interested in
-            num_rows = shape[0]
+            num_rows = s[0]
             row_ids_np = np.random.randint(num_rows, size=num_rows)
-            row_ids = mx.nd.array(row_ids_np).reshape((num_rows/2, 2))
+            row_ids = mx.nd.array(row_ids_np).reshape((num_rows/2, 
2)).astype(dtype)
             # perform pull
-            val = mx.nd.zeros(shape, stype='row_sparse')
-            kv.row_sparse_pull('9', out=val, row_ids=row_ids)
+            val = mx.nd.zeros(s, stype='row_sparse', dtype=dtype)
+            kv.row_sparse_pull(k, out=val, row_ids=row_ids)
             # prepare updated values
-            updated_val = mx.nd.ones(shape)
+            updated_val = mx.nd.ones(s, dtype=dtype)
             for rank in range(nworker):
-                row = rank % shape[0]
+                row = rank % s[0]
                 updated_val[row] += (rank + 1) * rate * (i+1)
             # verify subset of updated values
-            expected = mx.nd.zeros(shape)
+            expected = mx.nd.zeros(s, dtype=dtype)
             for row in row_ids_np:
                 expected[row] = updated_val[row]
-            check_diff_to_scalar(val, expected)
+            check_diff(val, expected, kv.rank)
 
-    def check_row_sparse_keys_with_zeros(kv, my_rank, nworker):
-        nrepeat = 3
+    def check_row_sparse_keys_with_zeros(dtype, nrepeat):
+        if dtype == 'float32':
+            k1 = rsp_keys_shape[1]
+            k2 = rsp_keys_big_shape[0]
+        else:
+            k1 = fp16_rsp_keys_shape[1]
+            k2 = fp16_rsp_keys_big_shape[0]
         # prepare gradient
-        v = mx.nd.sparse.zeros('row_sparse', shape)
-        big_v = mx.nd.sparse.zeros('row_sparse', big_shape)
+        v = mx.nd.sparse.zeros('row_sparse', shape, dtype=dtype)
+        big_v = mx.nd.sparse.zeros('row_sparse', big_shape, dtype=dtype)
         # push
         for i in range(nrepeat):
-            kv.push('11', v)
-            kv.push('100', big_v)
+            kv.push(k1, v)
+            kv.push(k2, big_v)
             # pull a subset of rows this worker is interested in
             all_row_ids = np.arange(shape[0])
             val = mx.nd.sparse.zeros('row_sparse', shape)
             big_val = mx.nd.sparse.zeros('row_sparse', big_shape)
-            kv.row_sparse_pull('11', out=val, row_ids=mx.nd.array(all_row_ids))
+            kv.row_sparse_pull(k1, out=val, row_ids=mx.nd.array(all_row_ids))
             big_all_row_ids = np.arange(big_shape[0])
-            kv.row_sparse_pull('100', out=big_val, 
row_ids=mx.nd.array(big_all_row_ids))
+            kv.row_sparse_pull(k2, out=big_val, 
row_ids=mx.nd.array(big_all_row_ids))
             # verify results
-            check_diff_to_scalar(val, 1)
-            check_diff_to_scalar(big_val, 1)
+            check_diff(val, 1)
+            check_diff(big_val, 1)
             # pull empty weights
-            kv.row_sparse_pull('11', out=val, row_ids=mx.nd.array([]))
-            kv.row_sparse_pull('100', out=big_val, row_ids=mx.nd.array([]))
-            check_diff_to_scalar(val, 0)
-            check_diff_to_scalar(big_val, 0)
+            kv.row_sparse_pull(k1, out=val, row_ids=mx.nd.array([]))
+            kv.row_sparse_pull(k2, out=big_val, row_ids=mx.nd.array([]))
+            check_diff(val, 0)
+            check_diff(big_val, 0)
+
+    def check_big_row_sparse_keys(dtype, nrepeat):
+        if dtype == 'float32':
+            k = rsp_keys_big_shape[0]
+        else:
+            k = fp16_rsp_keys_big_shape[0]
 
-    def check_big_row_sparse_keys(kv, my_rank, nworker):
         mx.random.seed(123)
         rnd.seed(123)
         density = 0.3
-        nrepeat = 3
         # prepare gradient
-        v = mx.nd.zeros(big_shape)
+        v = mx.nd.zeros(big_shape, dtype=dtype)
         idx_sample = rnd.rand(big_shape[0])
         indices = np.argwhere(idx_sample < density).flatten()
         # each worker chooses a subset of the indices to update
@@ -163,98 +200,102 @@ def test_sync_push_pull():
             v[row] = my_rank + 1
         # push
         for i in range(nrepeat):
-            kv.push('100', v.tostype('row_sparse'))
-
+            kv.push(k, v.tostype('row_sparse'))
             # select a random subset of rows this worker is interested in
             mx.random.seed(my_rank)
             rnd.seed(my_rank)
             num_rows = big_shape[0]
             row_ids_np = np.random.randint(num_rows, size=num_rows)
-            row_ids = mx.nd.array(row_ids_np).reshape((num_rows/2, 2))
+            row_ids = mx.nd.array(row_ids_np).reshape((num_rows/2, 
2)).astype(dtype)
             # perform pull
-            val = mx.nd.zeros(big_shape, stype='row_sparse')
-            kv.row_sparse_pull('100', out=val, row_ids=row_ids)
+            val = mx.nd.zeros(big_shape, stype='row_sparse', dtype=dtype)
+            kv.row_sparse_pull(k, out=val, row_ids=row_ids)
             # prepare expected result
-            updated_val = mx.nd.ones(big_shape)
+            updated_val = mx.nd.ones(big_shape, dtype=dtype)
             # apply updates from each worker
             for rank in range(nworker):
                 for row in update_rows[rank]:
                     updated_val[row] += (rank + 1) * rate * (i+1)
 
-            expected = mx.nd.zeros(big_shape)
+            expected = mx.nd.zeros(big_shape, dtype=dtype)
             for row in row_ids_np:
                 expected[row] = updated_val[row]
-            check_diff_to_scalar(val, expected, rank=my_rank)
+            check_diff(val, expected.astype(dtype), rank=my_rank)
 
-    def check_compr_residual(kv, threshold, nworker):
-        for k,s in [('1121', shape),('112221',irregular_shape),('11221', 
big_shape)]:
+    for dtype in ['float16', 'float32']:
+        check_default_keys(dtype, nrepeat)
+        check_row_sparse_keys(dtype, nrepeat)
+        check_row_sparse_keys_with_zeros(dtype, nrepeat)
+        check_big_row_sparse_keys(dtype, nrepeat)
+    print('worker ' + str(my_rank) + ' is done with non compression tests')
+
+def test_sync_2bit_compression(threshold, nrepeat):
+    def check_compr_residual(threshold):
+        for k, s in compr_keys_shapes:
             # doesn't meet threshold
-            kv.push(k, mx.nd.ones(s)*0.4)
-            val=mx.nd.zeros(s)
+            kv.push(k, mx.nd.ones(s) * 0.4)
+            val = mx.nd.zeros(s)
             kv.pull(k,val)
-            check_diff_to_scalar(val, 0)
+            check_diff(val, 0)
 
             # just meets threshold with residual
-            kv.push(k, mx.nd.ones(s)*(threshold - 0.4))
+            kv.push(k, mx.nd.ones(s) * (threshold - 0.4))
             val2 = mx.nd.zeros(s)
             kv.pull(k,val2)
             curval = threshold * rate * nworker
-            check_diff_to_scalar(val2, curval)
+            check_diff(val2, curval)
 
             # doesn't meet threshold
-            kv.push(k, mx.nd.ones(s)*0.2)
-            val3= mx.nd.zeros(s)
+            kv.push(k, mx.nd.ones(s) * 0.2)
+            val3 = mx.nd.zeros(s)
             kv.pull(k, val3)
-            check_diff_to_scalar(val3, curval)
+            check_diff(val3, curval)
 
             # exceeds again
-            kv.push(k, mx.nd.ones(s)*(threshold-0.2))
+            kv.push(k, mx.nd.ones(s) * (threshold-0.2))
             val4 = mx.nd.zeros(s)
-            kv.pull(k,val4)
-            curval += threshold*rate*nworker
-            check_diff_to_scalar(val4, curval)
+            kv.pull(k, val4)
+            curval += threshold * rate * nworker
+            check_diff(val4, curval)
             # residual is 0 now
 
-    def check_compr_ones(kv, threshold, nworker):
-        for k,s in [('1121', shape),('112221',irregular_shape),('11221', 
big_shape)]:
+    def check_compr_ones(threshold):
+        for k, s in compr_keys_shapes:
             val = mx.nd.zeros(s)
             kv.pull(k, val)
             curval = val[0][0].asnumpy()[0]
-            kv.push(k,mx.nd.ones(s)*threshold)
+            kv.push(k,mx.nd.ones(s) * threshold)
             val2 = mx.nd.zeros(s)
             kv.pull(k, val2)
-            newval = curval + rate*nworker*threshold
-            check_diff_to_scalar(val2, newval)
+            newval = curval + rate * nworker * threshold
+            check_diff(val2, newval)
             # residual = 0  again
 
-    def check_compr_pull_before_push(kv):
-        for k,s in [('1121', shape),('112221',irregular_shape),
-                    ('11221', big_shape), ('1122',shape)]:
-            if k=='1122':
-                # tests that GC is not used for init of a key
-                val = mx.nd.zeros(s)
-                kv.pull(k, val)
-                check_diff_to_scalar(val, 1)
-            else:
-                val = mx.nd.ones(s)
-                kv.pull(k, val)
-                check_diff_to_scalar(val, 0)
+    def check_compr_pull_before_push():
+        for k,s in compr_keys_shapes:
+            val = mx.nd.ones(s)
+            kv.pull(k, val)
+            check_diff(val, 0)
+        for k, s in compr_init_keys_shapes:
+            # tests that GC is not used for init of a key
+            val = mx.nd.zeros(s)
+            kv.pull(k, val)
+            check_diff(val, 1)
 
-    def check_compr_zero(kv):
-        for k,s in [('1121', shape),('112221',irregular_shape),('11221', 
big_shape)]:
+    def check_compr_zero():
+        for k,s in compr_keys_shapes:
             kv.push(k, mx.nd.zeros(s))
             # to check that all are set to 0s
             val = mx.nd.ones(s)
             kv.pull(k, val)
-            check_diff_to_scalar(val, 0)
+            check_diff(val, 0)
 
-    def check_compr_random(kv, threshold, nworker):
+    def check_compr_random(threshold, nrepeat):
         # set a seed so all workers generate same data. knowing this helps
         # calculate expected value after pull
         mx.random.seed(123)
         rnd.seed(123)
-        nrepeat = 5
-        compr_random_keys_shapes = [('2121', 
shape),('212221',irregular_shape),('21221', big_shape)]
+
         # use new keys so residual is 0 for calculation of expected
         for k,s in compr_random_keys_shapes:
             kv.init(k, mx.nd.zeros(s))
@@ -278,39 +319,51 @@ def test_sync_push_pull():
                 decompr *= nworker * rate
                 assert_almost_equal(diff.asnumpy(), decompr)
 
-    print ('worker '+str(my_rank)+' started with non compression tests')
-    check_default_keys(kv, my_rank, nworker)
-    check_row_sparse_keys(kv, my_rank, nworker)
-    check_row_sparse_keys_with_zeros(kv, my_rank, nworker)
-    check_big_row_sparse_keys(kv, my_rank, nworker)
-    print('worker ' + str(my_rank) + ' is done with non compression tests')
-
-    # don't run non compressed keys after this as kvstore now is set to 
compressed
-    print ('worker '+str(my_rank)+' started with compression tests')
-    kv, threshold = init_kv_compressed(kv)
-    check_compr_pull_before_push(kv)
-    check_compr_zero(kv)
-    check_compr_residual(kv, threshold, nworker)
-    check_compr_ones(kv, threshold, nworker)
-    check_compr_random(kv, threshold, nworker)
+    print ('worker ' + str(my_rank) + ' started with compression tests')
+    check_compr_pull_before_push()
+    check_compr_zero()
+    check_compr_residual(threshold)
+    check_compr_ones(threshold)
+    check_compr_random(threshold, nrepeat)
     print('worker ' + str(my_rank) + ' is done with compression tests')
 
-def test_sync_init():
+def test_sync_init(gpu_tests=False):
+    def get_dtype(idx, cur_keys):
+        if idx < len(cur_keys)/2:
+            dtype = 'float32'
+        else:
+            dtype = 'float16'
+        return dtype
+
     def check_init(kv, cur_keys, cur_shape, device=False):
         ctx = mx.gpu(0) if device else mx.cpu()
-        val = [mx.nd.zeros(cur_shape, ctx) for i in cur_keys]
+        val = [mx.nd.zeros(cur_shape, ctx=ctx, dtype=get_dtype(i, cur_keys)) 
for i in range(len(cur_keys))]
         for i in range(len(cur_keys)):
             expected = i
-            kv.init(cur_keys[i], [mx.nd.ones(cur_shape, ctx) * i])
+            kv.init(cur_keys[i], [mx.nd.ones(cur_shape, ctx=ctx, 
dtype=get_dtype(i, cur_keys)) * i])
             kv.pull(cur_keys[i], out=val[i])
-            check_diff_to_scalar(val[i], expected)
+            check_diff(val[i], expected)
     check_init(kv, init_test_keys, shape)
     check_init(kv, init_test_keys_big, big_shape)
-    check_init(kv, init_test_keys_device, shape, device=True)
-    check_init(kv, init_test_keys_device_big, big_shape, device=True)
-    my_rank = kv.rank
-    print('worker ' + str(my_rank) + ' is initialized')
+    if gpu_tests:
+        check_init(kv, init_test_keys_device, shape, device=True)
+        check_init(kv, init_test_keys_device_big, big_shape, device=True)
+    print('worker ' + str(kv.rank) + ' is initialized')
 
 if __name__ == "__main__":
-    test_sync_init()
-    test_sync_push_pull()
+    parser = argparse.ArgumentParser(description='test distributed kvstore in 
dist_sync mode')
+    parser.add_argument('--nrepeat', type=int, default=7)
+    parser.add_argument('--type', type=str, default='all')
+    parser.add_argument('--no-gpu', dest='gpu', action='store_false')
+    parser.add_argument('--no-multiprecision', dest='multiprecision', 
action='store_false')
+    opt = parser.parse_args()
+    if opt.type == 'all' or  opt.type == 'init':
+        test_sync_init(opt.gpu)
+    kv = init_kv()
+    if opt.type == 'all' or  opt.type == 'default':
+        kv = set_optimizer(use_multiprecision=opt.multiprecision)
+        test_sync_push_pull(opt.nrepeat)
+    # don't run non compressed tests after this as kvstore compression will be 
set here
+    if opt.type == 'all' or  opt.type == 'compressed':
+        kv, threshold = init_kv_compressed(kv)
+        test_sync_2bit_compression(threshold, opt.nrepeat)

-- 
To stop receiving notification emails like this one, please contact
j...@apache.org.

Reply via email to