piiswrong closed pull request #10183: [MXNET-120] Float16 support for distributed training
URL: https://github.com/apache/incubator-mxnet/pull/10183
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, because GitHub does not display it otherwise:

diff --git a/3rdparty/mshadow b/3rdparty/mshadow
index f5b67f380cb..0b4cedd7015 160000
--- a/3rdparty/mshadow
+++ b/3rdparty/mshadow
@@ -1 +1 @@
-Subproject commit f5b67f380cb0588be11e6f440f92f013139380ee
+Subproject commit 0b4cedd7015cc69191f8338a8feaacda90697758
diff --git a/CODEOWNERS b/CODEOWNERS
index 3660e382ebe..1ea9b567c4a 100644
--- a/CODEOWNERS
+++ b/CODEOWNERS
@@ -17,9 +17,12 @@
 /perl-package/    @sergeykolychev
 /python/          @szha
 
+# C++ base
+/src/kvstore/     @rahul003
+
 # CMake
-CMakeLists.txt    @szha
-/cmake/           @szha
+CMakeLists.txt    @szha @rahul003
+/cmake/           @szha @rahul003
 
 # MXNet CI
 /tests/ci_build/    @marcoabreu
diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index a32e33ee1ab..2ba0721533c 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -163,4 +163,5 @@ List of Contributors
 * [David Braude](https://github.com/dabraude/)
 * [Nick Robinson](https://github.com/nickrobinson)
 * [Kan Wu](https://github.com/wkcn)
+* [Rahul Huilgol](https://github.com/rahul003)
 * [Anirudh Subramanian](https://github.com/anirudh2290/)
diff --git a/Jenkinsfile b/Jenkinsfile
index 3892906b45a..8686012164d 100644
--- a/Jenkinsfile
+++ b/Jenkinsfile
@@ -671,6 +671,17 @@ try {
           }
         }
       }
+    },
+    'dist-kvstore tests GPU': {
+      node('mxnetlinux-gpu') {
+        ws('workspace/it-dist-kvstore') {
+          init_git()
+          unpack_lib('gpu')
+          timeout(time: max_time, unit: 'MINUTES') {
+            sh "ci/build.py --nvidiadocker --platform ubuntu_gpu 
/work/runtime_functions.sh integrationtest_ubuntu_gpu_dist_kvstore"
+          }
+        }
+      }
     }
   }
 
diff --git a/amalgamation/Makefile b/amalgamation/Makefile
index 9c45885b7cf..f7f3c001e19 100644
--- a/amalgamation/Makefile
+++ b/amalgamation/Makefile
@@ -51,6 +51,15 @@ endif
 DEFS+=-DMSHADOW_USE_CUDA=0 -DMSHADOW_USE_MKL=0 -DMSHADOW_RABIT_PS=0 
-DMSHADOW_DIST_PS=0 -DDMLC_LOG_STACK_TRACE=0
 DEFS+=-DMSHADOW_FORCE_STREAM -DMXNET_USE_OPENCV=0 -DMXNET_PREDICT_ONLY=1
 CFLAGS=-std=c++11 -Wno-unknown-pragmas -Wall $(DEFS)
+
+# if architecture of the CPU supports F16C instruction set, enable USE_F16C 
for fast fp16 computation on CPU
+ifeq ($(USE_F16C), 1)
+       CFLAGS+=-mf16c
+       DEFS+=-DMSHADOW_USE_F16C=1
+else
+       DEFS+=-DMSHADOW_USE_F16C=0
+endif
+
 ifneq ($(MIN), 1)
        CFLAGS += -I${OPENBLAS_ROOT} -I${OPENBLAS_ROOT}/include
        LDFLAGS+= -L${OPENBLAS_ROOT} -L${OPENBLAS_ROOT}/lib
diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh
index 484b49c8c0e..fcb6424cce1 100755
--- a/ci/docker/runtime_functions.sh
+++ b/ci/docker/runtime_functions.sh
@@ -495,6 +495,15 @@ integrationtest_ubuntu_gpu_cpp_package() {
     cpp-package/tests/ci_test.sh
 }
 
+integrationtest_ubuntu_gpu_dist_kvstore() {
+    set -ex
+    export PYTHONPATH=./python/
+    export MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
+    cd tests/nightly/
+    ../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py
+    ../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py 
--no-multiprecision
+    ../../tools/launch.py -n 7 --launcher local python 
dist_device_sync_kvstore.py
+}
 
 test_ubuntu_cpu_python2() {
     set -ex
diff --git a/make/crosscompile.jetson.mk b/make/crosscompile.jetson.mk
index 9ca4109fa0e..31a1398c1b7 100644
--- a/make/crosscompile.jetson.mk
+++ b/make/crosscompile.jetson.mk
@@ -132,7 +132,7 @@ endif
 # Settings for power and arm arch
 #----------------------------
 USE_SSE=0
-
+USE_F16C=0
 #----------------------------
 # distributed computing
 #----------------------------
diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py
index 88b2e880d38..e730fd7cd8b 100644
--- a/python/mxnet/gluon/trainer.py
+++ b/python/mxnet/gluon/trainer.py
@@ -114,12 +114,13 @@ def _init_kvstore(self):
                 kvstore.set_gradient_compression(self._compression_params)
             if 'dist' in kvstore.type:
                 update_on_kvstore = False
+            if update_on_kvstore:
+                kvstore.set_optimizer(self._optimizer)
+            # optimizer preferably needs to be set before init for 
multiprecision
             for i, param in enumerate(self._params):
                 param_arrays = param.list_data()
                 kvstore.init(i, param_arrays[0])
                 kvstore.pull(i, param_arrays, priority=-i)
-            if update_on_kvstore:
-                kvstore.set_optimizer(self._optimizer)
             self._kvstore = kvstore
             self._update_on_kvstore = update_on_kvstore
         else:
diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py
index 5520597530e..f31dac01cd1 100644
--- a/python/mxnet/kvstore.py
+++ b/python/mxnet/kvstore.py
@@ -83,6 +83,14 @@ def updater_handle(key, lhs_handle, rhs_handle, _):
         updater(key, lhs, rhs)
     return updater_handle
 
+def _get_kvstore_server_command_type(command):
+    command_types = {'kController': 0,
+                     'kSetMultiPrecision': 1,
+                     'kStopServer': 2,
+                     'kSyncMode': 3,
+                     'kSetGradientCompression': 4}
+    assert (command in command_types), "Unknown command type to send to server"
+    return command_types[command]
 
 class KVStore(object):
     """A key-value store for synchronization of values, over multiple 
devices."""
@@ -473,7 +481,11 @@ def set_optimizer(self, optimizer):
                 optim_str = py_str(pickle.dumps(optimizer, 0))
             except:
                 raise
-            self._send_command_to_servers(0, optim_str)
+            cmd = _get_kvstore_server_command_type('kController')
+            self._send_command_to_servers(cmd, optim_str)
+            if optimizer.multi_precision:
+                cmd = _get_kvstore_server_command_type('kSetMultiPrecision')
+                self._send_command_to_servers(cmd, '')
         else:
             self._set_updater(opt.get_updater(optimizer))
 
diff --git a/python/mxnet/model.py b/python/mxnet/model.py
index 26e885a1cd8..ae7726d76a7 100644
--- a/python/mxnet/model.py
+++ b/python/mxnet/model.py
@@ -253,6 +253,8 @@ def _train_multi_device(symbol, ctx, arg_names, 
param_names, aux_names,
 
     if not update_on_kvstore:
         updater = get_updater(optimizer)
+    else:
+        kvstore.set_optimizer(optimizer)
 
     if kvstore:
         _initialize_kvstore(kvstore=kvstore,
@@ -261,9 +263,6 @@ def _train_multi_device(symbol, ctx, arg_names, 
param_names, aux_names,
                             param_names=executor_manager.param_names,
                             update_on_kvstore=update_on_kvstore)
 
-    if update_on_kvstore:
-        kvstore.set_optimizer(optimizer)
-
     # Now start training
     train_data.reset()
     for epoch in range(begin_epoch, end_epoch):
diff --git a/python/mxnet/module/module.py b/python/mxnet/module/module.py
index 21d9b568e37..a05c3a31cd2 100644
--- a/python/mxnet/module/module.py
+++ b/python/mxnet/module/module.py
@@ -536,15 +536,16 @@ def init_optimizer(self, kvstore='local', optimizer='sgd',
         if kvstore:
             if self._compression_params:
                 kvstore.set_gradient_compression(self._compression_params)
+            if update_on_kvstore:
+                kvstore.set_optimizer(self._optimizer)
             # copy initialized local parameters to kvstore
             _initialize_kvstore(kvstore=kvstore,
                                 param_arrays=self._exec_group.param_arrays,
                                 arg_params=self._arg_params,
                                 param_names=self._param_names,
                                 update_on_kvstore=update_on_kvstore)
-        if update_on_kvstore:
-            kvstore.set_optimizer(self._optimizer)
-        else:
+
+        if not update_on_kvstore:
             self._updater = opt.get_updater(optimizer)
 
         self.optimizer_initialized = True
diff --git a/src/kvstore/kvstore_dist.h b/src/kvstore/kvstore_dist.h
index afba9ac5f27..373081bc7b1 100644
--- a/src/kvstore/kvstore_dist.h
+++ b/src/kvstore/kvstore_dist.h
@@ -47,7 +47,7 @@ class KVStoreDist : public KVStoreLocal {
       : KVStoreLocal(use_device_comm), ps_worker_(nullptr), server_(nullptr) {
     if (IsWorkerNode()) {
       int new_customer_id = GetNewCustomerId();
-      ps_worker_ = new ps::KVWorker<real_t>(0, new_customer_id);
+      ps_worker_ = new ps::KVWorker<char>(0, new_customer_id);
       ps::StartAsync(new_customer_id, "mxnet\0");
       if (!ps::Postoffice::Get()->is_recovery()) {
         ps::Postoffice::Get()->Barrier(
@@ -228,17 +228,18 @@ class KVStoreDist : public KVStoreLocal {
           RunContext rctx, Engine::CallbackOnComplete cb) {
         // convert to ps keys
         size_t size = recv_buf.shape().Size();
-
+        const int dtype = recv_buf.dtype();
+        const int num_bytes = mshadow::mshadow_sizeof(dtype);
         PSKV& pskv = (gradient_compression_->get_type() == 
CompressionType::kNone) ?
-                      EncodeDefaultKey(key, size, false) :
-                      EncodeCompressedKey(key, size, false);
-        real_t* data = recv_buf.data().dptr<real_t>();
+                      EncodeDefaultKey(key, size, num_bytes) :
+                      EncodeCompressedKey(key, size, false, num_bytes);
+        char* data = static_cast<char*> (recv_buf.data().dptr_);
         // false means not to delete data when SArray is deleted
-        auto vals = new ps::SArray<real_t>(data, size, false);
+        auto vals = new ps::SArray<char>(data, size * num_bytes, false);
         // issue pull
-        int cmd = (gradient_compression_->get_type() != 
CompressionType::kNone) ?
-                  static_cast<int>(DataHandleType::kCompressedPushPull) :
-                  static_cast<int>(DataHandleType::kDefaultPushPull);
+        RequestType mode = (gradient_compression_->get_type() != 
CompressionType::kNone) ?
+                  RequestType::kCompressedPushPull : 
RequestType::kDefaultPushPull;
+        const int cmd = GetCommandType(mode, dtype);
         CHECK_NOTNULL(ps_worker_)->ZPull(
           pskv.keys, vals, &pskv.lens, cmd, [vals, cb](){ delete vals; cb(); 
});
       };
@@ -329,18 +330,21 @@ class KVStoreDist : public KVStoreLocal {
         }
         CopyFromTo(merged, &comm_buf);
       }
-
+      const int dtype = merged.dtype();
+      const int num_bytes = mshadow::mshadow_sizeof(dtype);
       // push to servers
       if (storage_type == kDefaultStorage) {
         if (gradient_compression_->get_type() == CompressionType::kNone) {
-          PSKV& pskv = EncodeDefaultKey(key, comm_buf.shape().Size(), true);
+          PSKV& pskv = EncodeDefaultKey(key, comm_buf.shape().Size(), 
num_bytes);
           PushDefault(key, comm_buf, pskv, priority);
         } else {
+          CHECK_EQ(dtype, mshadow::kFloat32) << "Gradient compression is only 
supported for "
+                                             << "float32 type of parameters";
           // Note: gradient compression uses `do_merge` as proxy to
           // detect whether the push is initialization of a key or not.
           // is_active is false when push is initialization of key
           bool is_active = do_merge;
-          PSKV &pskv = EncodeCompressedKey(key, comm_buf.shape().Size(), 
is_active);
+          PSKV &pskv = EncodeCompressedKey(key, comm_buf.shape().Size(), 
is_active, num_bytes);
           // Returns push_pskv if active, else pull_pskv
           // we want inactive gc to send uncompressed gradients,
           // but sharded in the same way as later pushes would when gc becomes 
active
@@ -363,25 +367,24 @@ class KVStoreDist : public KVStoreLocal {
   void PushCompressed(int key, const NDArray& comm_buf, const PSKV& pskv, int 
priority) {
     auto &small_buf = compr_buf_[key];
     auto &res_buf = residual_[key];
-    size_t original_size = comm_buf.shape().Size();
+    const size_t original_size = comm_buf.shape().Size();
+    const int dtype = comm_buf.dtype();
 
     // Init the small buffer and residual_ buffer for quantize
     if (small_buf.is_none()) {
-      small_buf = NDArray(TShape{pskv.size}, comm_buf.ctx(), false, 
comm_buf.dtype());
-      res_buf = NDArray(TShape{(int64_t) original_size}, comm_buf.ctx(),
-                        false, comm_buf.dtype());
+      small_buf = NDArray(TShape{pskv.size}, comm_buf.ctx(), false, dtype);
+      res_buf = NDArray(TShape{static_cast<int64_t>(original_size)}, 
comm_buf.ctx(), false, dtype);
       res_buf = 0;
     }
     gradient_compression_->Quantize(comm_buf, &small_buf, &res_buf, priority);
     auto push_to_servers =
-      [this, key, pskv, small_buf](RunContext rctx, Engine::CallbackOnComplete 
cb) {
-        size_t size = small_buf.shape().Size();
-        real_t* data = small_buf.data().dptr<real_t>();
+      [this, key, dtype, pskv, small_buf](RunContext rctx, 
Engine::CallbackOnComplete cb) {
+        size_t size = small_buf.shape().Size() * 
mshadow::mshadow_sizeof(dtype);
+        char* data = static_cast<char *> (small_buf.data().dptr_);
         // do push. false means no delete
-        ps::SArray<real_t> vals(data, size, false);
-        CHECK_NOTNULL(ps_worker_)->ZPush(
-          pskv.keys, vals, pskv.lens,
-          static_cast<int>(DataHandleType::kCompressedPushPull), [cb]() { 
cb(); });
+        ps::SArray<char> vals(data, size, false);
+        int cmd = GetCommandType(RequestType::kCompressedPushPull, dtype);
+        CHECK_NOTNULL(ps_worker_)->ZPush(pskv.keys, vals, pskv.lens, cmd, 
[cb]() { cb(); });
       };
     // acquire locks on both comm_buf and small_buf so that
     // pull (which uses comm_buf) for the same key waits till push finishes
@@ -398,14 +401,16 @@ class KVStoreDist : public KVStoreLocal {
   void PushDefault(int key, const NDArray &send_buf, const PSKV& pskv, int 
priority) {
     auto push_to_servers =
         [this, key, pskv, send_buf](RunContext rctx, 
Engine::CallbackOnComplete cb) {
+          const int dtype = send_buf.dtype();
           // convert to ps keys
-          size_t size = send_buf.shape().Size();
-          real_t* data = send_buf.data().dptr<real_t>();
+          const size_t size = send_buf.shape().Size() * 
mshadow::mshadow_sizeof(dtype);
+          char* data = static_cast<char *>(send_buf.data().dptr_);
           // do push. false means no delete
-          ps::SArray<real_t> vals(data, size, false);
+          ps::SArray<char> vals(data, size, false);
+          int cmd = GetCommandType(RequestType::kDefaultPushPull, dtype);
           CHECK_NOTNULL(ps_worker_)->ZPush(
               pskv.keys, vals, pskv.lens,
-              static_cast<int>(DataHandleType::kDefaultPushPull), [cb]() { 
cb(); });
+              cmd, [cb]() { cb(); });
         };
     Engine::Get()->PushAsync(
         push_to_servers,
@@ -422,23 +427,22 @@ class KVStoreDist : public KVStoreLocal {
     using namespace rowsparse;
     auto push_to_servers = [this, key, send_buf]
                            (RunContext rctx, Engine::CallbackOnComplete cb) {
-      real_t* data = send_buf.data().dptr<real_t>();
+      char* data = static_cast<char *>(send_buf.data().dptr_);
       const int64_t num_rows = send_buf.aux_shape(kIdx)[0];
       const auto offsets = send_buf.aux_data(kIdx).dptr<int64_t>();
       const auto unit_len = send_buf.shape().ProdShape(1, 
send_buf.shape().ndim());
+      const int num_bytes = mshadow::mshadow_sizeof(send_buf.dtype());
       const int64_t size = num_rows * unit_len;
-
        // convert to ps keys in row sparse format
       PSKV& pskv = EncodeRowSparseKey(key, size, num_rows, offsets,
-                                      unit_len, send_buf.shape()[0]);
+                                      unit_len, send_buf.shape()[0], 
num_bytes);
       if (this->log_verbose_) {
         LOG(INFO) << "worker " << get_rank() << " push lens: " << pskv.lens << 
" keys: "
                   << pskv.keys << " size: " << size;
       }
-      ps::SArray<real_t> vals(data, size, false);
-      CHECK_NOTNULL(ps_worker_)->ZPush(pskv.keys, vals, pskv.lens,
-                                       
static_cast<int>(DataHandleType::kRowSparsePushPull),
-                                       [cb]() { cb(); });
+      ps::SArray<char> vals(data, size * num_bytes, false);
+      const int cmd = GetCommandType(RequestType::kRowSparsePushPull, 
send_buf.dtype());
+      CHECK_NOTNULL(ps_worker_)->ZPush(pskv.keys, vals, pskv.lens, cmd, [cb]() 
{ cb(); });
     };
     Engine::Get()->PushAsync(
         push_to_servers,
@@ -460,27 +464,31 @@ class KVStoreDist : public KVStoreLocal {
       // allocate memory for the buffer
       CHECK_EQ(indices.dtype(), mshadow::kInt64);
       const TBlob idx_data = indices.data();
-      size_t num_rows = idx_data.shape_.Size();
+      const size_t num_rows = idx_data.shape_.Size();
       recv_buf.CheckAndAlloc({mshadow::Shape1(num_rows)});
-      real_t* data = recv_buf.data().dptr<real_t>();
+      const int dtype = recv_buf.dtype();
+      char* data = static_cast<char *>(recv_buf.data().dptr_);
       const auto offsets = idx_data.dptr<int64_t>();
       const auto unit_len = recv_buf.shape().ProdShape(1, 
recv_buf.shape().ndim());
       const int64_t size = num_rows * unit_len;
+      const int num_bytes = mshadow::mshadow_sizeof(dtype);
       // convert to ps keys in row sparse format
       PSKV& pskv = EncodeRowSparseKey(key, size, num_rows, offsets,
-                                      unit_len, recv_buf.shape()[0]);
+                                      unit_len, recv_buf.shape()[0],
+                                      num_bytes);
       if (this->log_verbose_) {
         LOG(INFO) << "worker " << get_rank() << " pull lens: " << pskv.lens << 
" keys: "
                   << pskv.keys << " size: " << size;
       }
-      auto vals = new ps::SArray<real_t>(data, size, false);
+      auto vals = new ps::SArray<char>(data, size * num_bytes, false);
+      const int cmd = GetCommandType(RequestType::kRowSparsePushPull, 
recv_buf.dtype());
       // copy indices to recv_buf. this needs to be done before ZPull
       // because after pull is done, the callback function returns and locks 
are released.
       // at this point, later functions may access the indices variable while 
copy happens
       mshadow::Copy(recv_buf.aux_data(kIdx).FlatTo1D<cpu, int64_t>(),
                     idx_data.FlatTo1D<cpu, int64_t>());
       CHECK_NOTNULL(ps_worker_)->ZPull(pskv.keys, vals, &pskv.lens,
-                                       
static_cast<int>(DataHandleType::kRowSparsePushPull),
+                                       cmd,
                                        [vals, cb]() { delete vals; cb(); });
     };
     CHECK_NOTNULL(Engine::Get())->PushAsync(
@@ -504,67 +512,82 @@ class KVStoreDist : public KVStoreLocal {
   }
 
   /**
-   * \brief convert to keys in ps
+   * \brief convert to pskv for parameter server
+   * \param key
+   * \param num_arr_elems number of elements in the value for key
+   * \param num_bytes size of each element in number of bytes
+   * \return PSKV used for both push and pull
    */
-  inline PSKV& EncodeDefaultKey(int key, size_t size, bool is_push) {
+  inline PSKV& EncodeDefaultKey(const int key, const size_t num_arr_elems,
+                                const int num_bytes) {
     mu_.lock();
     PSKV& pskv = ps_kv_[key];
     mu_.unlock();
+    size_t pskv_size = num_arr_elems * num_bytes;
     if (!pskv.keys.empty()) {
-      CHECK_EQ(static_cast<size_t>(pskv.size), size) << "The value size cannot 
be changed";
+      CHECK_EQ(static_cast<size_t>(pskv.size), pskv_size)
+        << "The value size cannot be changed " << pskv_size << ". Key is " << 
key;
     } else {
       auto krs = ps::Postoffice::Get()->GetServerKeyRanges();
-      int num_servers = krs.size();
+      const int num_servers = krs.size();
       CHECK_GT(num_servers, 0);
 
       // a simple heuristic for load balance
-      if (size < bigarray_bound_) {
+      if (num_arr_elems < bigarray_bound_) {
         // send it to a single random picked server
         int server = (key * 9973) % num_servers;
         ps::Key ps_key = krs[server].begin() + key;
         CHECK_LT(ps_key, krs[server].end());
         pskv.keys.push_back(ps_key);
-        pskv.lens.push_back(size);
-        pskv.size = size;
+        const int total_bytes = num_arr_elems * num_bytes;
+        pskv.lens.push_back(total_bytes);
+        pskv.size = total_bytes;
       } else {
         // parition it to all servers
         pskv.size = 0;
         for (int i = 0; i < num_servers; ++i) {
           size_t part_size =
-            
static_cast<size_t>(round(static_cast<double>(size)/num_servers*(i+1))) -
-            
static_cast<size_t>(round(static_cast<double>(size)/num_servers*i));
+            
static_cast<size_t>(round(static_cast<double>(num_arr_elems)/num_servers*(i+1)))
 -
+            
static_cast<size_t>(round(static_cast<double>(num_arr_elems)/num_servers*i));
           ps::Key ps_key = krs[i].begin() + key;
           CHECK_LT(ps_key, krs[i].end());
           pskv.keys.push_back(ps_key);
-          pskv.lens.push_back(part_size);
-          pskv.size += part_size;
+          const int total_bytes = part_size * num_bytes;
+          pskv.lens.push_back(total_bytes);
+          pskv.size += total_bytes;
         }
-        CHECK_EQ(static_cast<size_t>(pskv.size), size);
       }
+      CHECK_EQ(static_cast<size_t>(pskv.size), pskv_size);
     }
     return pskv;
   }
 
   /**
-   * \brief Convert to keys in ps for compressed values
-   * Divides original array into equal parts for each server
-   * Populates both push and pull pskv on first call
+   * \brief Convert to PSKV for pushes and pulls when gradient compression is 
used.
+   * Divides original array into equal parts for each server.
+   * Populates both push and pull pskv on first call.
+   * \param key
+   * \param num_arr_elems number of elements in the value for key
+   * \param is_push whether this is push or pull
+   * \param num_bytes size of each element in number of bytes
+   * \return PSKV used for both push and pull
    */
-  inline PSKV& EncodeCompressedKey(int key, size_t original_size, bool 
is_push) {
+  inline PSKV& EncodeCompressedKey(const int key, const size_t 
original_num_elem,
+                                   const bool is_push, const int num_bytes) {
     auto krs = ps::Postoffice::Get()->GetServerKeyRanges();
-    int num_servers = krs.size();
+    const int num_servers = krs.size();
     CHECK_GT(num_servers, 0);
 
     // represents size of data to be sent
-    size_t compr_size = 
gradient_compression_->GetCompressedSize(original_size);
-
+    size_t compr_num_elem = 
gradient_compression_->GetCompressedSize(original_num_elem);
     mu_.lock();
     PSKV& pskv = (is_push) ? compr_ps_kv_[key].push : compr_ps_kv_[key].pull;
     mu_.unlock();
 
     if (!pskv.keys.empty()) {
-      size_t size = (is_push) ? compr_size : original_size;
-      CHECK_EQ(static_cast<size_t >(pskv.size), size)<< "The value size can't 
be changed";
+      const size_t num_elem = (is_push) ? compr_num_elem : original_num_elem;
+      CHECK_EQ(static_cast<size_t >(pskv.size), num_elem * num_bytes)
+        << "The value size can't be changed. For key " << key;
     } else {
       // populate both pull and push pskvs
       // push pskv has sizes corresponding to compressed data
@@ -574,18 +597,20 @@ class KVStoreDist : public KVStoreLocal {
       PSKV& push_pskv = compr_ps_kv_[key].push;
       mu_.unlock();
 
-      if (original_size < bigarray_bound_) {
+      if (original_num_elem < bigarray_bound_) {
         // a simple heuristic for load balancing
         // send it to a single random picked server
-        int server = (key * 9973) % num_servers;
+        const int server = (key * 9973) % num_servers;
         ps::Key ps_key = krs[server].begin() + key;
         CHECK_LT(ps_key, krs[server].end());
         // meta info
-        push_pskv.keys.push_back(krs[server].begin() + original_size);
+        push_pskv.keys.push_back(krs[server].begin() + original_num_elem);
         push_pskv.lens.push_back(0);
         // data
         push_pskv.keys.push_back(ps_key);
         pull_pskv.keys.push_back(ps_key);
+        const int compr_size = compr_num_elem * num_bytes;
+        const int original_size = original_num_elem * num_bytes;
         push_pskv.lens.push_back(compr_size);
         pull_pskv.lens.push_back(original_size);
         push_pskv.size = compr_size;
@@ -598,12 +623,12 @@ class KVStoreDist : public KVStoreLocal {
         for (int i = 0; i < num_servers; ++i) {
           size_t part_compr, part_orig;
           if (i == num_servers-1) {
-            part_compr = compr_size - push_pskv.size;
-            part_orig = original_size - pull_pskv.size;
+            part_compr = compr_num_elem - push_pskv.size;
+            part_orig = original_num_elem - pull_pskv.size;
           } else {
             part_compr =
-              static_cast<size_t> 
(round(static_cast<double>(compr_size)/num_servers*(i+1))) -
-              static_cast<size_t> 
(round(static_cast<double>(compr_size)/num_servers*(i)));
+              static_cast<size_t> 
(round(static_cast<double>(compr_num_elem)/num_servers*(i+1))) -
+              static_cast<size_t> 
(round(static_cast<double>(compr_num_elem)/num_servers*(i)));
             part_orig = part_compr * 
gradient_compression_->GetCompressionFactor();
           }
 
@@ -618,25 +643,27 @@ class KVStoreDist : public KVStoreLocal {
           CHECK_LT(ps_key, krs[i].end());
           push_pskv.keys.push_back(ps_key);
           pull_pskv.keys.push_back(ps_key);
-          // push_pskv stores lengths of compressed blocks
-          push_pskv.lens.push_back(part_compr);
-          // pull_pskv stores lengths of original data
-          pull_pskv.lens.push_back(part_orig);
+          push_pskv.lens.push_back(part_compr * num_bytes);
+          pull_pskv.lens.push_back(part_orig * num_bytes);
+          // num elements need to be inserted below so that for last server,
+          // there is no round off error
           push_pskv.size += part_compr;
           pull_pskv.size += part_orig;
         }
-        CHECK_EQ(static_cast<size_t>(push_pskv.size), compr_size);
-        CHECK_EQ(static_cast<size_t>(pull_pskv.size), original_size);
-        CHECK_EQ(push_pskv.lens.size(), num_servers*2);
+        CHECK_EQ(static_cast<size_t>(push_pskv.size), compr_num_elem);
+        CHECK_EQ(static_cast<size_t>(pull_pskv.size), original_num_elem);
+        push_pskv.size *= num_bytes;
+        pull_pskv.size *= num_bytes;
+        CHECK_EQ(push_pskv.lens.size(), num_servers * 2);
         }
       }
     return pskv;
   }
 
   // Note: this encoding method for row sparse keys doesn't allow cross-layer 
batching
-  inline PSKV& EncodeRowSparseKey(const int key, const int64_t size, const 
int64_t num_rows,
+  inline PSKV& EncodeRowSparseKey(const int key, const int64_t num_elem, const 
int64_t num_rows,
                                   const int64_t *offsets, const size_t 
unit_len,
-                                  const int64_t total_num_rows) {
+                                  const int64_t total_num_rows, const int 
num_bytes) {
     using namespace common;
     mu_.lock();
     PSKV& pskv = ps_kv_[key];
@@ -645,7 +672,7 @@ class KVStoreDist : public KVStoreLocal {
     pskv.lens.clear();
     // TODO(haibin) cache this information
     auto krs = ps::Postoffice::Get()->GetServerKeyRanges();
-    int num_servers = krs.size();
+    const int num_servers = krs.size();
     CHECK_GT(num_servers, 0);
 
     if (total_num_rows * unit_len >= bigarray_bound_) {
@@ -656,7 +683,7 @@ class KVStoreDist : public KVStoreLocal {
         ps::Key master_key = krs[i].begin() + key;
         pskv.keys.push_back(master_key);
         pskv.lens.push_back(0);
-        if (offsets && size > 0) {
+        if (offsets && num_elem > 0) {
           // calculate partition ranges
           int64_t part_num_rows =
             llround(static_cast<double>(total_num_rows) / num_servers * (i + 
1)) -
@@ -669,16 +696,17 @@ class KVStoreDist : public KVStoreLocal {
             ps::Key ps_key = krs[i].begin() + key + (*offset - start_row);
             CHECK_LT(ps_key, krs[i].end());
             pskv.keys.push_back(ps_key);
-            pskv.lens.push_back(unit_len);
-            pskv.size += unit_len;
+            const int part_size = unit_len * num_bytes;
+            pskv.lens.push_back(part_size);
+            pskv.size += (part_size);
           }
           start_row = end_row;
         }
       }
-      CHECK_EQ(static_cast<size_t>(pskv.size), size);
+      CHECK_EQ(static_cast<size_t>(pskv.size), num_elem * num_bytes);
     } else {
       // send it to a single random picked server
-      int server = (key * 9973) % num_servers;
+      const int server = (key * 9973) % num_servers;
       ps::Key master_key = krs[server].begin() + key;
       pskv.keys.push_back(master_key);
       pskv.lens.push_back(0);
@@ -686,9 +714,9 @@ class KVStoreDist : public KVStoreLocal {
         ps::Key ps_key = krs[server].begin() + key + offsets[i];
         CHECK_LT(ps_key, krs[server].end());
         pskv.keys.push_back(ps_key);
-        pskv.lens.push_back(unit_len);
+        pskv.lens.push_back(unit_len * num_bytes);
       }
-      pskv.size = size;
+      pskv.size = num_elem * num_bytes;
     }
     return pskv;
   }
@@ -696,7 +724,7 @@ class KVStoreDist : public KVStoreLocal {
   /**
    * \brief for worker to push and pull data
    */
-  ps::KVWorker<real_t>* ps_worker_;
+  ps::KVWorker<char>* ps_worker_;
   /**
    * \brief the server handle
    */
diff --git a/src/kvstore/kvstore_dist_server.h 
b/src/kvstore/kvstore_dist_server.h
index c2ddcd8708d..421de27b39d 100644
--- a/src/kvstore/kvstore_dist_server.h
+++ b/src/kvstore/kvstore_dist_server.h
@@ -40,14 +40,52 @@
 namespace mxnet {
 namespace kvstore {
 
+// maintain same order in frontend.
 enum class CommandType {
-  kController, kStopServer, kSyncMode, kSetGradientCompression
+  kController, kSetMultiPrecision, kStopServer, kSyncMode, 
kSetGradientCompression,
 };
 
-enum class DataHandleType {
-  kDefaultPushPull, kCompressedPushPull, kRowSparsePushPull
+enum class RequestType {
+  kDefaultPushPull, kRowSparsePushPull, kCompressedPushPull
 };
 
+struct DataHandleType {
+  RequestType requestType;
+  int dtype;
+};
+
+/*!
+ * Uses Cantor pairing function to generate a unique number given two numbers.
+ * This number can also be inverted to find the unique pair whose Cantor value 
is this number.
+ * Ref: https://en.wikipedia.org/wiki/Pairing_function#Cantor_pairing_function
+ * \param requestType RequestType
+ * \param dtype integer
+ * \return Cantor value of arguments
+ */
+static int GetCommandType(RequestType requestType, int d) {
+  int m = static_cast<int>(requestType);
+  return (((m + d) * (m + d + 1)) / 2) + d;
+}
+
+/*!
+ * Unpairs Cantor value and finds the two integers used to pair.
+ * Then returns DataHandleType object with those numbers.
+ * \param cmd DataHandleCommand generated by GetCommandType function
+ * \return DataHandleType
+ */
+static DataHandleType DepairDataHandleType(int cmd) {
+  int w = std::floor((std::sqrt(8 * cmd + 1) - 1)/2);
+  int t = ((w * w) + w) / 2;
+  int y = cmd - t;
+  int x = w - y;
+  CHECK_GE(x, 0);
+  CHECK_GE(y, 0);
+  DataHandleType type;
+  type.requestType = static_cast<RequestType>(x);
+  type.dtype = y;
+  return type;
+}
+
 /**
  * \brief executor runs a function using the thread called \ref Start
  */
@@ -114,7 +152,7 @@ class KVStoreDistServer {
  public:
   KVStoreDistServer() {
     using namespace std::placeholders;
-    ps_server_ = new ps::KVServer<float>(0);
+    ps_server_ = new ps::KVServer<char>(0);
     static_cast<ps::SimpleApp*>(ps_server_)->set_request_handle(
         std::bind(&KVStoreDistServer::CommandHandle, this, _1, _2));
     ps_server_->set_request_handle(
@@ -146,9 +184,11 @@ class KVStoreDistServer {
   }
 
  private:
-  struct MergeBuf {
+  struct UpdateBuf {
     std::vector<ps::KVMeta> request;
-    NDArray array;
+    NDArray merged;
+    // temp_array is used to cast received values as float32 for computation 
if required
+    NDArray temp_array;
   };
 
   void CommandHandle(const ps::SimpleData& recved, ps::SimpleApp* app) {
@@ -159,54 +199,114 @@ class KVStoreDistServer {
       sync_mode_ = true;
     } else if (recved_type == CommandType::kSetGradientCompression) {
       gradient_compression_->DecodeParams(recved.body);
-    } else {
-      // this uses value 0 for message id from frontend
+    } else if (recved_type == CommandType::kSetMultiPrecision) {
+      // uses value 1 for message id from frontend
+      if (!multi_precision_) {
+        multi_precision_ = true;
+        CreateMultiPrecisionCopies();
+      }
+    } else if (recved_type == CommandType::kController) {
+      // value of 0
       // let the main thread to execute ctrl, which is necessary for python
       exec_.Exec([this, recved]() {
           CHECK(controller_);
           controller_(recved.head, recved.body);
         });
+    } else {
+      LOG(FATAL) << "Unknown command type received " << recved.head;
     }
     app->Response(recved);
   }
 
+  /*
+   * For keys already initialized, if necessary create stored_realt.
+   * This will only be used if by some wrong usage of kvstore,
+   * some keys are initialized before optimizer is set.
+   */
+  void CreateMultiPrecisionCopies() {
+    for (auto const& stored_entry : store_) {
+      const int key = stored_entry.first;
+      const NDArray& stored = stored_entry.second;
+      if (stored.dtype() != mshadow::kFloat32) {
+        auto& stored_realt = store_realt_[key];
+        if (stored.storage_type() == kRowSparseStorage) {
+          stored_realt = NDArray(kRowSparseStorage, stored.shape(), 
stored.ctx(),
+                                 true, mshadow::kFloat32);
+        } else {
+          stored_realt = NDArray(stored.shape(), stored.ctx(), false, 
mshadow::kFloat32);
+        }
+
+        auto& update = update_buf_[key];
+        if (!update.merged.is_none()) {
+          if (update.merged.storage_type() == kRowSparseStorage) {
+            update.merged = NDArray(kRowSparseStorage, update.merged.shape(), 
update.merged.ctx(),
+                                    true, mshadow::kFloat32);
+          } else {
+            update.merged = NDArray(update.merged.shape(), 
update.merged.ctx(), false,
+                                    mshadow::kFloat32);
+          }
+        }
+        CHECK(update.request.size() == 0)
+          << ps::MyRank() << "Multiprecision mode can not be set while pushes 
are underway."
+          << "Please set optimizer before pushing keys." << key << " " << 
update.request.size();
+
+        CopyFromTo(stored, stored_realt);
+      }
+    }
+    for (auto const& stored_realt_entry : store_realt_) {
+      stored_realt_entry.second.WaitToRead();
+    }
+  }
+
   void DataHandleEx(const ps::KVMeta& req_meta,
-                    const ps::KVPairs<real_t>& req_data,
-                    ps::KVServer<real_t>* server) {
-    DataHandleType recved_type = static_cast<DataHandleType>(req_meta.cmd);
-    if (recved_type == DataHandleType::kRowSparsePushPull) {
-      DataHandleRowSparse(req_meta, req_data, server);
-    } else if (recved_type == DataHandleType::kCompressedPushPull) {
-      DataHandleCompressed(req_meta, req_data, server);
-    } else {
-      DataHandleDefault(req_meta, req_data, server);
+                    const ps::KVPairs<char>& req_data,
+                    ps::KVServer<char>* server) {
+    DataHandleType type = DepairDataHandleType(req_meta.cmd);
+    switch (type.requestType) {
+      case RequestType::kRowSparsePushPull:
+        DataHandleRowSparse(type, req_meta, req_data, server);
+        break;
+      case RequestType::kCompressedPushPull:
+        DataHandleCompressed(type, req_meta, req_data, server);
+        break;
+      case RequestType::kDefaultPushPull:
+        DataHandleDefault(type, req_meta, req_data, server);
+        break;
     }
-    return;
   }
 
-  inline void ApplyUpdates(const int key, MergeBuf *merged, NDArray *stored,
-                           ps::KVServer<real_t>* server) {
-    if (merged->request.size() == (size_t) ps::NumWorkers()) {
+  inline bool has_multi_precision_copy(const DataHandleType type) {
+    return multi_precision_ && type.dtype != mshadow::kFloat32;
+  }
+
+  inline void ApplyUpdates(const DataHandleType type, const int key,
+                           UpdateBuf *update_buf, ps::KVServer<char>* server) {
+    if (!sync_mode_ || update_buf->request.size() == (size_t) 
ps::NumWorkers()) {
       // let the main thread to execute updater_, which is necessary for python
+      auto& stored = has_multi_precision_copy(type) ? store_realt_[key] : 
store_[key];
+      auto& update =  sync_mode_ ? update_buf->merged : update_buf->temp_array;
       if (updater_) {
-        exec_.Exec([this, key, merged, stored](){
-            CHECK(updater_);
-            updater_(key, merged->array, stored);
-          });
+        exec_.Exec([this, key, &update, &stored](){
+          CHECK(updater_);
+          updater_(key, update, &stored);
+        });
       } else {
+        CHECK(sync_mode_) << "Updater needs to be set for async mode";
         // if no updater, just copy
-        CopyFromTo(merged->array, stored);
+        CopyFromTo(update_buf->merged, &stored);
       }
+
       if (log_verbose_)  {
-        LOG(INFO) << "sync response to " << merged->request.size() << " 
workers";
+        LOG(INFO) << "sent response to " << update_buf->request.size() << " 
workers";
       }
-      for (const auto& req : merged->request) {
+      for (const auto& req : update_buf->request) {
         server->Response(req);
       }
-      merged->request.clear();
-      stored->WaitToRead();
+      update_buf->request.clear();
+      if (has_multi_precision_copy(type)) CopyFromTo(stored, store_[key]);
+      stored.WaitToRead();
     } else {
-      merged->array.WaitToRead();
+      update_buf->merged.WaitToRead();
     }
   }
 
@@ -220,175 +320,229 @@ class KVStoreDistServer {
     }
   }
 
-  void DataHandleRowSparse(const ps::KVMeta& req_meta,
-                       const ps::KVPairs<real_t>& req_data,
-                       ps::KVServer<real_t>* server) {
+  void AccumulateRowSparseGrads(const DataHandleType type,
+                                const NDArray& recved,
+                                UpdateBuf* updateBuf) {
+    NDArray out(kRowSparseStorage, updateBuf->merged.shape(), Context(), true,
+                has_multi_precision_copy(type) ? mshadow::kFloat32 : 
type.dtype);
+    if (has_multi_precision_copy(type)) CopyFromTo(recved, 
updateBuf->temp_array);
+    const NDArray& to_merge = has_multi_precision_copy(type) ? 
updateBuf->temp_array : recved;
+    // accumulate row_sparse gradients
+    // TODO(haibin) override + operator for row_sparse NDArray
+    // instead of calling BinaryComputeRspRsp directly
+    using namespace mshadow;
+    Engine::Get()->PushAsync(
+    [to_merge, updateBuf, out](RunContext ctx, Engine::CallbackOnComplete 
on_complete) {
+      op::ElemwiseBinaryOp::ComputeEx<cpu, op::mshadow_op::plus>(
+      {}, {}, {to_merge, updateBuf->merged}, {kWriteTo}, {out});
+      on_complete();
+    }, to_merge.ctx(), {to_merge.var(), updateBuf->merged.var()}, {out.var()},
+    FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME);
+    CopyFromTo(out, &(updateBuf->merged), 0);
+    updateBuf->merged.WaitToRead();
+  }
+
+  void RowSparsePullResponse(const DataHandleType type,
+                             const int master_key,
+                             const size_t num_rows,
+                             const ps::KVMeta& req_meta,
+                             const ps::KVPairs<char>& req_data,
+                             ps::KVServer<char>* server) {
+    if (log_verbose_) LOG(INFO) << "pull: " << master_key;
+    ps::KVPairs<char> response;
+    if (num_rows == 0) {
+      std::vector<int> lens(req_data.keys.size(), 0);
+      response.keys = req_data.keys;
+      response.lens.CopyFrom(lens.begin(), lens.end());
+      server->Response(req_meta, response);
+      return;
+    }
+    const NDArray& stored = store_[master_key];
+    if (has_multi_precision_copy(type)) stored.WaitToRead();
+    CHECK(!stored.is_none()) << "init " << master_key << " first";
+    auto shape = stored.shape();
+    auto unit_len = shape.ProdShape(1, shape.ndim());
+    const int num_bytes = mshadow::mshadow_sizeof(type.dtype);
+    const int unit_size = unit_len * num_bytes;
+    const char* data = static_cast<char *> (stored.data().dptr_);
+    auto len = num_rows * unit_size;
+    // concat values
+    response.vals.resize(len);
+    #pragma omp parallel for
+    for (size_t i = 1; i <= num_rows; i++) {
+      int key = DecodeKey(req_data.keys[i]);
+      int64_t row_id = key - master_key;
+      const auto src = data + row_id * unit_size;
+      auto begin = (i - 1) * unit_size;
+      auto end = i * unit_size;
+      response.vals.segment(begin, end).CopyFrom(src, unit_size);
+    }
+    // setup response
+    response.keys = req_data.keys;
+    std::vector<int> lens(req_data.keys.size(), unit_len);
+    lens[0] = 0;
+    response.lens.CopyFrom(lens.begin(), lens.end());
+    server->Response(req_meta, response);
+  }
+
+  void InitRowSparseStored(const DataHandleType type,
+                           const int master_key,
+                           const size_t num_rows,
+                           const ps::KVMeta& req_meta,
+                           const ps::KVPairs<char>& req_data,
+                           ps::KVServer<char>* server) {
+    auto& stored = has_multi_precision_copy(type) ? store_realt_[master_key] : 
store_[master_key];
+    int dtype = type.dtype;
+    int num_bytes = mshadow::mshadow_sizeof(dtype);
+    auto unit_len = req_data.lens[1] / num_bytes;
+    CHECK_GT(unit_len, 0);
+    size_t ds[] = {num_rows, (size_t) unit_len};
+    TShape dshape(ds, ds + 2);
+    CHECK_EQ(req_data.vals.size(), num_rows * unit_len * num_bytes);
+    TBlob recv_blob;
+    MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
+      recv_blob = TBlob(reinterpret_cast<DType*>(req_data.vals.data()), 
dshape, cpu::kDevMask);
+    })
+    NDArray recved = NDArray(recv_blob, 0);
+    stored = NDArray(kRowSparseStorage, dshape, Context(), true,
+                     has_multi_precision_copy(type) ? mshadow::kFloat32 : 
type.dtype);
+    if (has_multi_precision_copy(type)) {
+      store_[master_key] = NDArray(kRowSparseStorage, dshape, Context(), true, 
type.dtype);
+    }
+    Engine::Get()->PushAsync(
+    [this, recved, stored, type](RunContext ctx, Engine::CallbackOnComplete 
on_complete) {
+      NDArray rsp = stored;
+      stored.CheckAndAlloc({mshadow::Shape1(recved.shape()[0])});
+      mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
+      using namespace mxnet::op;
+      nnvm::dim_t nnr = rsp.shape()[0];
+      MSHADOW_IDX_TYPE_SWITCH(rsp.aux_type(rowsparse::kIdx), IType, {
+        IType* idx = rsp.aux_data(rowsparse::kIdx).dptr<IType>();
+        mxnet_op::Kernel<PopulateFullIdxRspKernel, cpu>::Launch(s, nnr, idx);
+      });
+      TBlob rsp_data = rsp.data();
+      // copies or casts as appropriate
+      ndarray::Copy<cpu, cpu>(recved.data(), &rsp_data, Context(), Context(), 
RunContext());
+      on_complete();
+    }, recved.ctx(), {recved.var()}, {stored.var()},
+    FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME);
+    if (has_multi_precision_copy(type)) {
+      CopyFromTo(stored, store_[master_key]);
+      store_[master_key].WaitToRead();
+    }
+    stored.WaitToRead();
+    server->Response(req_meta);
+  }
+
+  void DataHandleRowSparse(const DataHandleType type, const ps::KVMeta& 
req_meta,
+                           const ps::KVPairs<char>& req_data,
+                           ps::KVServer<char>* server) {
     int master_key = DecodeKey(req_data.keys[0]);
     auto num_rows = req_data.keys.size() - 1;
     auto& stored = store_[master_key];
     if (req_meta.push) {
       CHECK_GT(req_data.lens.size(), 0) << "req_data.lens cannot be empty";
       CHECK_EQ(req_data.lens[0], 0);
-      real_t* data = req_data.vals.data();
       if (stored.is_none()) {
         if (log_verbose_) LOG(INFO) << "initial push: " << master_key;
         // initialization
         CHECK_GT(num_rows, 0) << "init with empty data is not supported";
-        auto unit_len = req_data.lens[1];
-        CHECK_GT(unit_len, 0);
-        size_t ds[] = {num_rows, (size_t) unit_len};
-        TShape dshape(ds, ds + 2);
-        CHECK_EQ(req_data.vals.size(), num_rows * unit_len);
-        TBlob recv_blob(data, dshape, cpu::kDevMask);  // NOLINT(*)
-        NDArray recved = NDArray(recv_blob, 0);
-        stored = NDArray(kRowSparseStorage, dshape, Context());
-        Engine::Get()->PushAsync(
-          [recved, stored](RunContext ctx, Engine::CallbackOnComplete 
on_complete) {
-            NDArray rsp = stored;
-            stored.CheckAndAlloc({mshadow::Shape1(recved.shape()[0])});
-            mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
-            using namespace mxnet::op;
-            nnvm::dim_t nnr = rsp.shape()[0];
-            MSHADOW_IDX_TYPE_SWITCH(rsp.aux_type(rowsparse::kIdx), IType, {
-              IType* idx = rsp.aux_data(rowsparse::kIdx).dptr<IType>();
-              mxnet_op::Kernel<PopulateFullIdxRspKernel, cpu>::Launch(s, nnr, 
idx);
-            });
-            mshadow::Copy(rsp.data().FlatTo1D<cpu, float>(),
-                          recved.data().FlatTo1D<cpu, float>(), s);
-            on_complete();
-          }, recved.ctx(), {recved.var()}, {stored.var()},
-          FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME);
-        stored.WaitToRead();
-        server->Response(req_meta);
+        InitRowSparseStored(type, master_key, num_rows, req_meta, req_data, 
server);
         return;
-      }
-      // synced push
-      if (sync_mode_) {
-        if (log_verbose_) LOG(INFO) << "sync push: " << master_key << " " << 
req_data.keys;
-        auto& merged = merge_buf_[master_key];
-        if (merged.array.is_none()) {
-          merged.array = NDArray(kRowSparseStorage, stored.shape(), Context());
+      } else {
+        if (log_verbose_) LOG(INFO) << "push: " << master_key << " " << 
req_data.keys;
+        auto& updates = update_buf_[master_key];
+        if (sync_mode_ && updates.merged.is_none()) {
+          updates.merged = NDArray(kRowSparseStorage, stored.shape(), 
Context(), true,
+                                   has_multi_precision_copy(type) ? 
mshadow::kFloat32 : type.dtype);
         }
+        if (has_multi_precision_copy(type) && updates.temp_array.is_none()) {
+          updates.temp_array = NDArray(kRowSparseStorage, stored.shape(), 
Context(), false,
+                                       mshadow::kFloat32);
+        }
+
         if (num_rows == 0) {
-          // reset to zeros
-          if (merged.request.size() == 0) {
-            merged.array = NDArray(kRowSparseStorage, stored.shape(), 
Context());
+          if (sync_mode_) {
+            if (updates.request.empty()) {
+              // reset to zeros
+              int merged_dtype = has_multi_precision_copy(type) ? 
mshadow::kFloat32 : type.dtype;
+              updates.merged = NDArray(kRowSparseStorage, stored.shape(), 
Context(),
+                                       true, merged_dtype);
+            }  // else nothing to aggregate
+            updates.request.push_back(req_meta);
+            ApplyUpdates(type, master_key, &updates, server);
           } else {
-            // nothing to aggregate
+            server->Response(req_meta);
           }
-          merged.request.push_back(req_meta);
-          ApplyUpdates(master_key, &merged,  &stored, server);
-          return;
-        }
-        auto unit_len = req_data.lens[1];
-        CHECK_GT(unit_len, 0);
-        // indices
-        std::vector<int64_t> indices(num_rows);
-        DecodeRowIds(req_data.keys, indices.data(), master_key, num_rows);
-        // data
-        TBlob idx_blob(indices.data(), mshadow::Shape1(num_rows), 
cpu::kDevMask);
-        size_t ds[] = {(size_t) num_rows, (size_t) unit_len};
-        TShape dshape(ds, ds + 2);
-        TBlob recv_blob(data, dshape, cpu::kDevMask); // NOLINT(*)
-        // row_sparse NDArray
-        NDArray recved(kRowSparseStorage, stored.shape(), recv_blob, 
{idx_blob}, 0);
-
-        if (merged.request.size() == 0) {
-          CopyFromTo(recved, &merged.array, 0);
         } else {
-          NDArray out(kRowSparseStorage, stored.shape(), Context());
-          // accumulate row_sparse gradients
-          // TODO(haibin) override + operator for row_sparse NDArray
-          // instead of calling BinaryComputeRspRsp directly
-          using namespace mshadow;
-          Engine::Get()->PushAsync(
-            [recved, merged, out](RunContext ctx, Engine::CallbackOnComplete 
on_complete) {
-              op::ElemwiseBinaryOp::ComputeEx<cpu, op::mshadow_op::plus>(
-                {}, {}, {recved, merged.array}, {kWriteTo}, {out});
-              on_complete();
-            }, recved.ctx(), {recved.var(), merged.array.var()}, {out.var()},
-            FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME);
-          CopyFromTo(out, &merged.array, 0);
-        }
-        merged.request.push_back(req_meta);
-        ApplyUpdates(master_key, &merged,  &stored, server);
-      } else {
-        // async push
-        if (log_verbose_) LOG(INFO) << "async push: " << master_key;
-        if (num_rows == 0) {
-          server->Response(req_meta);
-          return;
+          auto unit_len = req_data.lens[1] / 
mshadow::mshadow_sizeof(type.dtype);
+          CHECK_GT(unit_len, 0);
+          // indices
+          std::vector<int64_t> indices(num_rows);
+          DecodeRowIds(req_data.keys, indices.data(), master_key, num_rows);
+
+          // data
+          TBlob idx_blob(indices.data(), mshadow::Shape1(num_rows), 
cpu::kDevMask);
+          size_t ds[] = {(size_t) num_rows, (size_t) unit_len};
+          TShape dshape(ds, ds + 2);
+          TBlob recv_blob;
+          MSHADOW_REAL_TYPE_SWITCH(type.dtype, DType, {
+            recv_blob = TBlob(reinterpret_cast<DType*>(req_data.vals.data()),
+                              dshape, cpu::kDevMask);
+          })
+          // row_sparse NDArray
+          NDArray recved(kRowSparseStorage, stored.shape(), recv_blob, 
{idx_blob}, 0);
+
+          if (updates.request.empty()) {
+            if (sync_mode_) {
+              CopyFromTo(recved, updates.merged);
+            } else {
+              if (has_multi_precision_copy(type)) {
+                CopyFromTo(recved, updates.temp_array);
+              } else {
+                updates.temp_array = recved;
+              }
+            }
+          } else {
+            CHECK(sync_mode_);
+            AccumulateRowSparseGrads(type, recved, &updates);
+          }
+          updates.request.push_back(req_meta);
+          ApplyUpdates(type, master_key, &updates, server);
         }
-        auto unit_len = req_data.lens[1];
-        CHECK_GT(unit_len, 0);
-        // indices
-        std::vector<int64_t> indices(num_rows);
-        DecodeRowIds(req_data.keys, indices.data(), master_key, num_rows);
-        TBlob idx_blob(indices.data(), mshadow::Shape1(num_rows), 
cpu::kDevMask);
-        size_t ds[] = {(size_t) num_rows, (size_t) unit_len};
-        TShape dshape(ds, ds + 2);
-        TBlob recv_blob(data, dshape, cpu::kDevMask); // NOLINT(*)
-        NDArray recved(kRowSparseStorage, stored.shape(), recv_blob, 
{idx_blob}, 0);
-        exec_.Exec([this, master_key, &recved, &stored](){
-            CHECK(updater_);
-            updater_(master_key, recved, &stored);
-          });
-        server->Response(req_meta);
-        stored.WaitToRead();
       }
     } else {
       // pull
-      if (log_verbose_) LOG(INFO) << "pull: " << master_key;
-      ps::KVPairs<real_t> response;
-      if (num_rows == 0) {
-        std::vector<int> lens(req_data.keys.size(), 0);
-        response.keys = req_data.keys;
-        response.lens.CopyFrom(lens.begin(), lens.end());
-        server->Response(req_meta, response);
-        return;
-      }
-      CHECK(!stored.is_none()) << "init " << master_key << " first";
-      auto shape = stored.shape();
-      auto unit_len = shape.ProdShape(1, shape.ndim());
-      const float* data = stored.data().dptr<float>();
-      auto len = unit_len * num_rows;
-      // concat values
-      response.vals.resize(len);
-      #pragma omp parallel for
-      for (size_t i = 1; i <= num_rows; i++) {
-        int key = DecodeKey(req_data.keys[i]);
-        int64_t row_id = key - master_key;
-        const auto src = data + row_id * unit_len;
-        auto begin = (i - 1) * unit_len;
-        auto end = i * unit_len;
-        response.vals.segment(begin, end).CopyFrom(src, unit_len);
-      }
-      // setup response
-      response.keys = req_data.keys;
-      std::vector<int> lens(req_data.keys.size(), unit_len);
-      lens[0] = 0;
-      response.lens.CopyFrom(lens.begin(), lens.end());
-      server->Response(req_meta, response);
+      RowSparsePullResponse(type, master_key, num_rows, req_meta, req_data, 
server);
     }
   }
 
-  void DefaultStorageResponse(int key, const NDArray& stored,
+  void DefaultStorageResponse(const DataHandleType type,
+                              const int key,
                               const ps::KVMeta& req_meta,
-                              const ps::KVPairs<real_t> &req_data,
-                              ps::KVServer<real_t>* server) {
-    ps::KVPairs<real_t> response;
+                              const ps::KVPairs<char> &req_data,
+                              ps::KVServer<char>* server) {
+    ps::KVPairs<char> response;
+    const NDArray& stored = store_[key];
     CHECK(!stored.is_none()) << "init " << key << " first";
-    auto len = stored.shape().Size();
+
+    // as server returns when store_realt is ready in this case
+    if (has_multi_precision_copy(type)) stored.WaitToRead();
+
+    auto len = stored.shape().Size() * mshadow::mshadow_sizeof(stored.dtype());
     response.keys = req_data.keys;
     response.lens = {len};
     // TODO(mli) try to remove this CopyFrom
-    response.vals.CopyFrom(static_cast<const float*>(stored.data().dptr_), 
len);
+    response.vals.CopyFrom(static_cast<const char*>(stored.data().dptr_), len);
     server->Response(req_meta, response);
   }
 
-  void DataHandleCompressed(const ps::KVMeta& req_meta,
-                            const ps::KVPairs<real_t> &req_data,
-                            ps::KVServer<real_t>* server) {
+  void DataHandleCompressed(const DataHandleType type,
+                            const ps::KVMeta& req_meta,
+                            const ps::KVPairs<char> &req_data,
+                            ps::KVServer<char>* server) {
+    CHECK_EQ(type.dtype, mshadow::kFloat32)
+      << "Gradient compression is currently supported for fp32 only";
     if (req_meta.push) {
       // there used several WaitToRead, this is because \a recved's memory
       // could be deallocated when this function returns. so we need to make 
sure
@@ -403,10 +557,9 @@ class KVStoreDistServer {
       int key = DecodeKey(req_data.keys[1]);
       auto& stored = store_[key];
 
-      size_t ds[] = {(size_t)req_data.lens[1]};
+      size_t ds[] = {(size_t)req_data.lens[1] / 
mshadow::mshadow_sizeof(type.dtype)};
       TShape dshape(ds, ds + 1);
-      TBlob recv_blob((real_t*) req_data.vals.data(), // NOLINT(*)
-                      dshape, cpu::kDevMask);
+      TBlob recv_blob(reinterpret_cast<real_t*>(req_data.vals.data()), dshape, 
cpu::kDevMask);
       NDArray recved = NDArray(recv_blob, 0);
 
       NDArray decomp_buf = decomp_buf_[key];
@@ -423,18 +576,18 @@ class KVStoreDistServer {
         stored.WaitToRead();
       } else if (sync_mode_) {
         // synced push
-        auto& merged = merge_buf_[key];
-        if (merged.array.is_none()) {
-          merged.array = NDArray(dshape, Context());
+        auto& merged = update_buf_[key];
+        if (merged.merged.is_none()) {
+          merged.merged = NDArray(dshape, Context());
         }
         if (merged.request.size() == 0) {
-          gradient_compression_->Dequantize(recved, &merged.array, 0);
+          gradient_compression_->Dequantize(recved, &merged.merged, 0);
         } else {
           gradient_compression_->Dequantize(recved, &decomp_buf, 0);
-          merged.array += decomp_buf;
+          merged.merged += decomp_buf;
         }
         merged.request.push_back(req_meta);
-        ApplyUpdates(key, &merged, &stored, server);
+        ApplyUpdates(type, key, &merged, server);
       } else {
         // async push
         gradient_compression_->Dequantize(recved, &decomp_buf, 0);
@@ -449,63 +602,78 @@ class KVStoreDistServer {
       CHECK_EQ(req_data.keys.size(), (size_t)1);
       CHECK_EQ(req_data.lens.size(), (size_t)0);
       int key = DecodeKey(req_data.keys[0]);
-      DefaultStorageResponse(key, store_[key], req_meta, req_data, server);
+      DefaultStorageResponse(type, key, req_meta, req_data, server);
     }
   }
 
-  void DataHandleDefault(const ps::KVMeta& req_meta,
-                         const ps::KVPairs<real_t> &req_data,
-                         ps::KVServer<real_t>* server) {
-    CHECK_EQ(req_meta.cmd, static_cast<int>(DataHandleType::kDefaultPushPull));
+  void DataHandleDefault(const DataHandleType type, const ps::KVMeta& req_meta,
+                         const ps::KVPairs<char> &req_data,
+                         ps::KVServer<char>* server) {
     // do some check
     CHECK_EQ(req_data.keys.size(), (size_t)1);
     if (req_meta.push) {
       CHECK_EQ(req_data.lens.size(), (size_t)1);
       CHECK_EQ(req_data.vals.size(), (size_t)req_data.lens[0]);
     }
-
     int key = DecodeKey(req_data.keys[0]);
-    auto& stored = store_[key];
-
+    auto& stored = has_multi_precision_copy(type) ? store_realt_[key] : 
store_[key];
     // there used several WaitToRead, this is because \a recved's memory
     // could be deallocated when this function returns. so we need to make sure
     // the operators with \a NDArray are actually finished
     if (req_meta.push) {
-      size_t ds[] = {(size_t)req_data.lens[0]};
+      size_t ds[] = {(size_t) req_data.lens[0] / 
mshadow::mshadow_sizeof(type.dtype)};
       TShape dshape(ds, ds + 1);
-      TBlob recv_blob((real_t*)req_data.vals.data(), // NOLINT(*)
-                      dshape, cpu::kDevMask);
+      TBlob recv_blob;
+      MSHADOW_REAL_TYPE_SWITCH(type.dtype, DType, {
+        recv_blob = TBlob(reinterpret_cast<DType*>(req_data.vals.data()), 
dshape, cpu::kDevMask);
+      })
       NDArray recved = NDArray(recv_blob, 0);
       if (stored.is_none()) {
         // initialization
-        stored = NDArray(dshape, Context());
+        stored = NDArray(dshape, Context(), false,
+                         has_multi_precision_copy(type) ? mshadow::kFloat32 : 
type.dtype);
         CopyFromTo(recved, &stored, 0);
         server->Response(req_meta);
+        if (has_multi_precision_copy(type)) {
+          auto& stored_dtype = store_[key];
+          stored_dtype = NDArray(dshape, Context(), false, type.dtype);
+          CopyFromTo(stored, stored_dtype);
+          stored_dtype.WaitToRead();
+        }
         stored.WaitToRead();
-      } else if (sync_mode_) {
-        // synced push
-        auto& merged = merge_buf_[key];
-        if (merged.array.is_none()) {
-          merged.array = NDArray(dshape, Context());
+      } else {
+        auto &updates = update_buf_[key];
+        if (sync_mode_ && updates.merged.is_none()) {
+          updates.merged = NDArray(dshape, Context(), false,
+                                   has_multi_precision_copy(type) ? 
mshadow::kFloat32 : type.dtype);
         }
-        if (merged.request.size() == 0) {
-          CopyFromTo(recved, &merged.array, 0);
+        if (has_multi_precision_copy(type) && updates.temp_array.is_none()) {
+          updates.temp_array = NDArray(dshape, Context(), false, 
mshadow::kFloat32);
+        }
+        if (updates.request.empty()) {
+          if (sync_mode_) {
+            CopyFromTo(recved, updates.merged);
+          } else {
+            if (has_multi_precision_copy(type)) {
+              CopyFromTo(recved, updates.temp_array);
+            } else {
+              updates.temp_array = recved;
+            }
+          }
         } else {
-          merged.array += recved;
+          CHECK(sync_mode_);
+          if (has_multi_precision_copy(type)) {
+            CopyFromTo(recved, updates.temp_array);
+            updates.merged += updates.temp_array;
+          } else {
+            updates.merged += recved;
+          }
         }
-        merged.request.push_back(req_meta);
-        ApplyUpdates(key, &merged, &stored, server);
-      } else {
-        // async push
-        exec_.Exec([this, key, &recved, &stored](){
-            CHECK(updater_);
-            updater_(key, recved, &stored);
-          });
-        server->Response(req_meta);
-        stored.WaitToRead();
+        updates.request.push_back(req_meta);
+        ApplyUpdates(type, key, &updates, server);
       }
     } else {
-      DefaultStorageResponse(key, stored, req_meta, req_data, server);
+      DefaultStorageResponse(type, key, req_meta, req_data, server);
     }
   }
 
@@ -526,13 +694,14 @@ class KVStoreDistServer {
    * \brief store_ contains the value at kvstore for each key
    */
   std::unordered_map<int, NDArray> store_;
+  std::unordered_map<int, NDArray> store_realt_;
 
   /**
    * \brief merge_buf_ is a buffer used if sync_mode is true. It represents
    * values from different workers being merged. The store will be updated
    * to this value when values from all workers are pushed into this buffer.
    */
-  std::unordered_map<int, MergeBuf> merge_buf_;
+  std::unordered_map<int, UpdateBuf> update_buf_;
 
   /**
    * \brief decomp_buf_ is a buffer into which compressed values are
@@ -541,11 +710,18 @@ class KVStoreDistServer {
   std::unordered_map<int, NDArray> decomp_buf_;
 
   Executor exec_;
-  ps::KVServer<float>* ps_server_;
+  ps::KVServer<char>* ps_server_;
 
   // whether to LOG verbose information
   bool log_verbose_;
 
+  /*
+   * \brief whether to use multi precision mode.
+   * in multi precision mode, all weights are stored as float32.
+   * any gradient received will be cast to float32 before accumulation and 
updating of weights.
+   */
+  bool multi_precision_;
+
   /**
    * \brief gradient compression object.
    * starts with none, used after SetGradientCompression sets the type
diff --git a/tests/nightly/dist_sync_kvstore.py 
b/tests/nightly/dist_sync_kvstore.py
index 3a3c916d782..3bf5cbffa13 100644
--- a/tests/nightly/dist_sync_kvstore.py
+++ b/tests/nightly/dist_sync_kvstore.py
@@ -20,132 +20,169 @@
 # pylint: skip-file
 import sys
 sys.path.insert(0, "../../python/")
+import argparse
 import mxnet as mx
 import numpy as np
 import numpy.random as rnd
 from mxnet.test_utils import assert_almost_equal
 from test_kvstore import compute_expected_2bit_quantization
 
-def check_diff_to_scalar(A, x, rank=None):
-    """ assert A == x"""
-    assert(np.sum(np.abs((A - x).asnumpy())) == 0), (rank, A.asnumpy(), x)
+def check_diff(A, x, rank=None):
+    """ assert A == x
+        x can be scalar as well as numpy array
+    """
+    assert (np.sum(np.abs((A - x).asnumpy())) == 0), (rank, A.asnumpy(), 
x.asnumpy())
 
 # setup
-keys = ['3', '5', '7']
-rsp_keys = ['9', '11', '13']
-init_test_keys = [str(i) for i in range(200,300)]
-init_test_keys_big = [str(i) for i in range(300,400)]
-init_test_keys_device = [str(i) for i in range(400,500)]
-init_test_keys_device_big = [str(i) for i in range(500,600)]
-
-rate = 2
 shape = (2, 3)
 irregular_shape = (1211,1211)
 big_shape = (1200, 1200)        # bigger than MXNET_KVSTORE_BIGARRAY_BOUND
 
+keys_shape = ['3', '5', '7']
+keys_big_shape = ['99']
+fp16_keys_shape = ['4', '6', '8']
+fp16_keys_big_shape = ['100']
+
+rsp_keys_shape = ['9', '11', '13']
+rsp_keys_big_shape = ['97']
+fp16_rsp_keys_shape = ['10', '12', '14']
+fp16_rsp_keys_big_shape = ['98']
+
+keys_shapes = [(k, shape) for k in keys_shape] + [(k, big_shape) for k in 
keys_big_shape]
+fp16_keys_shapes = [(k, shape) for k in fp16_keys_shape] + [(k, big_shape) for 
k in fp16_keys_big_shape]
+
+init_test_keys = [str(i) for i in range(200, 300)]
+init_test_keys_big = [str(i) for i in range(300, 400)]
+init_test_keys_device = [str(i) for i in range(400, 500)]
+init_test_keys_device_big = [str(i) for i in range(500, 600)]
+
+compr_keys_shapes = [('1000', shape), ('1200', irregular_shape),('1300', 
big_shape)]
+compr_init_keys_shapes = [('1001', shape), ('1201', irregular_shape),('1301', 
big_shape)]
+compr_random_keys_shapes = [('1002', shape),('1202', irregular_shape),('1302', 
big_shape)]
+
+rate = 2
+
 kv = mx.kv.create('dist_sync')
 
+my_rank = kv.rank
+nworker = kv.num_workers
+
 def init_kv():
-    # init kv dns keys
-    kv.init(keys, [mx.nd.ones(shape)] * len(keys))
-    kv.init('99', mx.nd.ones(big_shape))
-    # init kv row_sparse keys
-    kv.init(rsp_keys, [mx.nd.ones(shape).tostype('row_sparse')] * 
len(rsp_keys))
-    kv.init('100', mx.nd.ones(big_shape).tostype('row_sparse'))
-    # worker info
-    my_rank = kv.rank
-    nworker = kv.num_workers
+    # # init kv dns keys
+    kv.init(keys_shape, [mx.nd.ones(shape)] * len(keys_shape))
+    kv.init(keys_big_shape, [mx.nd.ones(big_shape)] * len(keys_big_shape))
+    # # init kv row_sparse keys
+    kv.init(rsp_keys_shape, [mx.nd.ones(shape).tostype('row_sparse')] * 
len(rsp_keys_shape))
+    kv.init(rsp_keys_big_shape, [mx.nd.ones(big_shape).tostype('row_sparse')] 
* len(rsp_keys_big_shape))
+    # init fp16 dns keys
+    kv.init(fp16_keys_shape, [mx.nd.ones(shape, dtype='float16')] * 
len(keys_shape))
+    kv.init(fp16_keys_big_shape, [mx.nd.ones(big_shape, dtype='float16')] * 
len(keys_big_shape))
+    # init fp16 row_sparse keys
+    kv.init(fp16_rsp_keys_shape, [mx.nd.ones(shape, 
dtype='float16').tostype('row_sparse')] * len(fp16_rsp_keys_shape))
+    kv.init(fp16_rsp_keys_big_shape, [mx.nd.ones(big_shape, 
dtype='float16').tostype('row_sparse')] * len(fp16_rsp_keys_big_shape))
+    return kv
+
+def set_optimizer(use_multiprecision):
     # init updater on servers
-    kv.set_optimizer(mx.optimizer.create('test', rescale_grad=rate))
-    return kv, my_rank, nworker
+    kv.set_optimizer(mx.optimizer.create('test', rescale_grad=rate, 
multi_precision=use_multiprecision))
+    return kv
 
 def init_kv_compressed(kv):
     threshold = 0.5
-    kv.set_gradient_compression({'type': '2bit', 'threshold':threshold})
+    kv.set_gradient_compression({'type': '2bit', 'threshold': threshold})
     # init kv compression keys
-    kv.init('11221', mx.nd.zeros(big_shape))
-    kv.init('112221', mx.nd.zeros(irregular_shape))
-    kv.init('1121', mx.nd.zeros(shape))
+    for k, s in compr_keys_shapes:
+        kv.init(k, mx.nd.zeros(s))
     # to test inactive mode
-    kv.init('1122', mx.nd.ones(shape))
+    for k, s in compr_init_keys_shapes:
+        kv.init(k, mx.nd.ones(s))
     return kv, threshold
 
-def test_sync_push_pull():
-    kv, my_rank, nworker = init_kv()
-    def check_default_keys(kv, my_rank, nworker):
-        nrepeat = 3
+def test_sync_push_pull(nrepeat):
+    def check_default_keys(dtype, nrepeat):
         # checks pull after push in loop, because behavior during
         # consecutive pushes doesn't offer any guarantees
-        for i in range(nrepeat):
-            kv.push('3', mx.nd.ones(shape)*(my_rank+1))
-            kv.push('99', mx.nd.ones(big_shape)*(my_rank+1))
-            num = (nworker + 1) * nworker * rate / 2 * (i + 1) + 1
-            val = mx.nd.zeros(shape)
-            kv.pull('3', out=val)
-            check_diff_to_scalar(val, num)
-            val2 = mx.nd.zeros(big_shape)
-            kv.pull('99', out=val2)
-            check_diff_to_scalar(val2, num)
-
-    def check_row_sparse_keys(kv, my_rank, nworker):
-        nrepeat = 3
+        ks = keys_shapes if dtype == 'float32' else fp16_keys_shapes
+        for k, s in ks:
+            for i in range(nrepeat):
+                kv.push(k, mx.nd.ones(s, dtype=dtype)*(my_rank+1))
+                num = (nworker + 1) * nworker * rate / 2 * (i + 1) + 1
+                val = mx.nd.zeros(s, dtype=dtype)
+                kv.pull(k, out=val)
+                check_diff(val, num)
+
+    def check_row_sparse_keys(dtype, nrepeat):
         # prepare gradient
-        v = mx.nd.zeros(shape)
+        v = mx.nd.zeros(shape, dtype=dtype)
         my_row = my_rank % shape[0]
         v[my_row] = my_rank + 1
         # push
+        if dtype == 'float32':
+            k = rsp_keys_shape[0]
+        else:
+            k = fp16_rsp_keys_shape[0]
+        s = shape
         for i in range(nrepeat):
-            kv.push('9', v.tostype('row_sparse'))
+            kv.push(k, v.tostype('row_sparse'))
             # select a random subset of rows this worker is interested in
-            num_rows = shape[0]
+            num_rows = s[0]
             row_ids_np = np.random.randint(num_rows, size=num_rows)
-            row_ids = mx.nd.array(row_ids_np).reshape((num_rows/2, 2))
+            row_ids = mx.nd.array(row_ids_np).reshape((num_rows/2, 
2)).astype(dtype)
             # perform pull
-            val = mx.nd.zeros(shape, stype='row_sparse')
-            kv.row_sparse_pull('9', out=val, row_ids=row_ids)
+            val = mx.nd.zeros(s, stype='row_sparse', dtype=dtype)
+            kv.row_sparse_pull(k, out=val, row_ids=row_ids)
             # prepare updated values
-            updated_val = mx.nd.ones(shape)
+            updated_val = mx.nd.ones(s, dtype=dtype)
             for rank in range(nworker):
-                row = rank % shape[0]
+                row = rank % s[0]
                 updated_val[row] += (rank + 1) * rate * (i+1)
             # verify subset of updated values
-            expected = mx.nd.zeros(shape)
+            expected = mx.nd.zeros(s, dtype=dtype)
             for row in row_ids_np:
                 expected[row] = updated_val[row]
-            check_diff_to_scalar(val, expected)
+            check_diff(val, expected, kv.rank)
 
-    def check_row_sparse_keys_with_zeros(kv, my_rank, nworker):
-        nrepeat = 3
+    def check_row_sparse_keys_with_zeros(dtype, nrepeat):
+        if dtype == 'float32':
+            k1 = rsp_keys_shape[1]
+            k2 = rsp_keys_big_shape[0]
+        else:
+            k1 = fp16_rsp_keys_shape[1]
+            k2 = fp16_rsp_keys_big_shape[0]
         # prepare gradient
-        v = mx.nd.sparse.zeros('row_sparse', shape)
-        big_v = mx.nd.sparse.zeros('row_sparse', big_shape)
+        v = mx.nd.sparse.zeros('row_sparse', shape, dtype=dtype)
+        big_v = mx.nd.sparse.zeros('row_sparse', big_shape, dtype=dtype)
         # push
         for i in range(nrepeat):
-            kv.push('11', v)
-            kv.push('100', big_v)
+            kv.push(k1, v)
+            kv.push(k2, big_v)
             # pull a subset of rows this worker is interested in
             all_row_ids = np.arange(shape[0])
             val = mx.nd.sparse.zeros('row_sparse', shape)
             big_val = mx.nd.sparse.zeros('row_sparse', big_shape)
-            kv.row_sparse_pull('11', out=val, row_ids=mx.nd.array(all_row_ids))
+            kv.row_sparse_pull(k1, out=val, row_ids=mx.nd.array(all_row_ids))
             big_all_row_ids = np.arange(big_shape[0])
-            kv.row_sparse_pull('100', out=big_val, 
row_ids=mx.nd.array(big_all_row_ids))
+            kv.row_sparse_pull(k2, out=big_val, 
row_ids=mx.nd.array(big_all_row_ids))
             # verify results
-            check_diff_to_scalar(val, 1)
-            check_diff_to_scalar(big_val, 1)
+            check_diff(val, 1)
+            check_diff(big_val, 1)
             # pull empty weights
-            kv.row_sparse_pull('11', out=val, row_ids=mx.nd.array([]))
-            kv.row_sparse_pull('100', out=big_val, row_ids=mx.nd.array([]))
-            check_diff_to_scalar(val, 0)
-            check_diff_to_scalar(big_val, 0)
+            kv.row_sparse_pull(k1, out=val, row_ids=mx.nd.array([]))
+            kv.row_sparse_pull(k2, out=big_val, row_ids=mx.nd.array([]))
+            check_diff(val, 0)
+            check_diff(big_val, 0)
+
+    def check_big_row_sparse_keys(dtype, nrepeat):
+        if dtype == 'float32':
+            k = rsp_keys_big_shape[0]
+        else:
+            k = fp16_rsp_keys_big_shape[0]
 
-    def check_big_row_sparse_keys(kv, my_rank, nworker):
         mx.random.seed(123)
         rnd.seed(123)
         density = 0.3
-        nrepeat = 3
         # prepare gradient
-        v = mx.nd.zeros(big_shape)
+        v = mx.nd.zeros(big_shape, dtype=dtype)
         idx_sample = rnd.rand(big_shape[0])
         indices = np.argwhere(idx_sample < density).flatten()
         # each worker chooses a subset of the indices to update
@@ -163,98 +200,102 @@ def check_big_row_sparse_keys(kv, my_rank, nworker):
             v[row] = my_rank + 1
         # push
         for i in range(nrepeat):
-            kv.push('100', v.tostype('row_sparse'))
-
+            kv.push(k, v.tostype('row_sparse'))
             # select a random subset of rows this worker is interested in
             mx.random.seed(my_rank)
             rnd.seed(my_rank)
             num_rows = big_shape[0]
             row_ids_np = np.random.randint(num_rows, size=num_rows)
-            row_ids = mx.nd.array(row_ids_np).reshape((num_rows/2, 2))
+            row_ids = mx.nd.array(row_ids_np).reshape((num_rows/2, 
2)).astype(dtype)
             # perform pull
-            val = mx.nd.zeros(big_shape, stype='row_sparse')
-            kv.row_sparse_pull('100', out=val, row_ids=row_ids)
+            val = mx.nd.zeros(big_shape, stype='row_sparse', dtype=dtype)
+            kv.row_sparse_pull(k, out=val, row_ids=row_ids)
             # prepare expected result
-            updated_val = mx.nd.ones(big_shape)
+            updated_val = mx.nd.ones(big_shape, dtype=dtype)
             # apply updates from each worker
             for rank in range(nworker):
                 for row in update_rows[rank]:
                     updated_val[row] += (rank + 1) * rate * (i+1)
 
-            expected = mx.nd.zeros(big_shape)
+            expected = mx.nd.zeros(big_shape, dtype=dtype)
             for row in row_ids_np:
                 expected[row] = updated_val[row]
-            check_diff_to_scalar(val, expected, rank=my_rank)
+            check_diff(val, expected.astype(dtype), rank=my_rank)
 
-    def check_compr_residual(kv, threshold, nworker):
-        for k,s in [('1121', shape),('112221',irregular_shape),('11221', 
big_shape)]:
+    for dtype in ['float16', 'float32']:
+        check_default_keys(dtype, nrepeat)
+        check_row_sparse_keys(dtype, nrepeat)
+        check_row_sparse_keys_with_zeros(dtype, nrepeat)
+        check_big_row_sparse_keys(dtype, nrepeat)
+    print('worker ' + str(my_rank) + ' is done with non compression tests')
+
+def test_sync_2bit_compression(threshold, nrepeat):
+    def check_compr_residual(threshold):
+        for k, s in compr_keys_shapes:
             # doesn't meet threshold
-            kv.push(k, mx.nd.ones(s)*0.4)
-            val=mx.nd.zeros(s)
+            kv.push(k, mx.nd.ones(s) * 0.4)
+            val = mx.nd.zeros(s)
             kv.pull(k,val)
-            check_diff_to_scalar(val, 0)
+            check_diff(val, 0)
 
             # just meets threshold with residual
-            kv.push(k, mx.nd.ones(s)*(threshold - 0.4))
+            kv.push(k, mx.nd.ones(s) * (threshold - 0.4))
             val2 = mx.nd.zeros(s)
             kv.pull(k,val2)
             curval = threshold * rate * nworker
-            check_diff_to_scalar(val2, curval)
+            check_diff(val2, curval)
 
             # doesn't meet threshold
-            kv.push(k, mx.nd.ones(s)*0.2)
-            val3= mx.nd.zeros(s)
+            kv.push(k, mx.nd.ones(s) * 0.2)
+            val3 = mx.nd.zeros(s)
             kv.pull(k, val3)
-            check_diff_to_scalar(val3, curval)
+            check_diff(val3, curval)
 
             # exceeds again
-            kv.push(k, mx.nd.ones(s)*(threshold-0.2))
+            kv.push(k, mx.nd.ones(s) * (threshold-0.2))
             val4 = mx.nd.zeros(s)
-            kv.pull(k,val4)
-            curval += threshold*rate*nworker
-            check_diff_to_scalar(val4, curval)
+            kv.pull(k, val4)
+            curval += threshold * rate * nworker
+            check_diff(val4, curval)
             # residual is 0 now
 
-    def check_compr_ones(kv, threshold, nworker):
-        for k,s in [('1121', shape),('112221',irregular_shape),('11221', 
big_shape)]:
+    def check_compr_ones(threshold):
+        for k, s in compr_keys_shapes:
             val = mx.nd.zeros(s)
             kv.pull(k, val)
             curval = val[0][0].asnumpy()[0]
-            kv.push(k,mx.nd.ones(s)*threshold)
+            kv.push(k,mx.nd.ones(s) * threshold)
             val2 = mx.nd.zeros(s)
             kv.pull(k, val2)
-            newval = curval + rate*nworker*threshold
-            check_diff_to_scalar(val2, newval)
+            newval = curval + rate * nworker * threshold
+            check_diff(val2, newval)
             # residual = 0  again
 
-    def check_compr_pull_before_push(kv):
-        for k,s in [('1121', shape),('112221',irregular_shape),
-                    ('11221', big_shape), ('1122',shape)]:
-            if k=='1122':
-                # tests that GC is not used for init of a key
-                val = mx.nd.zeros(s)
-                kv.pull(k, val)
-                check_diff_to_scalar(val, 1)
-            else:
-                val = mx.nd.ones(s)
-                kv.pull(k, val)
-                check_diff_to_scalar(val, 0)
+    def check_compr_pull_before_push():
+        for k,s in compr_keys_shapes:
+            val = mx.nd.ones(s)
+            kv.pull(k, val)
+            check_diff(val, 0)
+        for k, s in compr_init_keys_shapes:
+            # tests that GC is not used for init of a key
+            val = mx.nd.zeros(s)
+            kv.pull(k, val)
+            check_diff(val, 1)
 
-    def check_compr_zero(kv):
-        for k,s in [('1121', shape),('112221',irregular_shape),('11221', 
big_shape)]:
+    def check_compr_zero():
+        for k,s in compr_keys_shapes:
             kv.push(k, mx.nd.zeros(s))
             # to check that all are set to 0s
             val = mx.nd.ones(s)
             kv.pull(k, val)
-            check_diff_to_scalar(val, 0)
+            check_diff(val, 0)
 
-    def check_compr_random(kv, threshold, nworker):
+    def check_compr_random(threshold, nrepeat):
         # set a seed so all workers generate same data. knowing this helps
         # calculate expected value after pull
         mx.random.seed(123)
         rnd.seed(123)
-        nrepeat = 5
-        compr_random_keys_shapes = [('2121', 
shape),('212221',irregular_shape),('21221', big_shape)]
+
         # use new keys so residual is 0 for calculation of expected
         for k,s in compr_random_keys_shapes:
             kv.init(k, mx.nd.zeros(s))
@@ -278,39 +319,51 @@ def check_compr_random(kv, threshold, nworker):
                 decompr *= nworker * rate
                 assert_almost_equal(diff.asnumpy(), decompr)
 
-    print ('worker '+str(my_rank)+' started with non compression tests')
-    check_default_keys(kv, my_rank, nworker)
-    check_row_sparse_keys(kv, my_rank, nworker)
-    check_row_sparse_keys_with_zeros(kv, my_rank, nworker)
-    check_big_row_sparse_keys(kv, my_rank, nworker)
-    print('worker ' + str(my_rank) + ' is done with non compression tests')
-
-    # don't run non compressed keys after this as kvstore now is set to 
compressed
-    print ('worker '+str(my_rank)+' started with compression tests')
-    kv, threshold = init_kv_compressed(kv)
-    check_compr_pull_before_push(kv)
-    check_compr_zero(kv)
-    check_compr_residual(kv, threshold, nworker)
-    check_compr_ones(kv, threshold, nworker)
-    check_compr_random(kv, threshold, nworker)
+    print ('worker ' + str(my_rank) + ' started with compression tests')
+    check_compr_pull_before_push()
+    check_compr_zero()
+    check_compr_residual(threshold)
+    check_compr_ones(threshold)
+    check_compr_random(threshold, nrepeat)
     print('worker ' + str(my_rank) + ' is done with compression tests')
 
-def test_sync_init():
+def test_sync_init(gpu_tests=False):
+    def get_dtype(idx, cur_keys):
+        if idx < len(cur_keys)/2:
+            dtype = 'float32'
+        else:
+            dtype = 'float16'
+        return dtype
+
     def check_init(kv, cur_keys, cur_shape, device=False):
         ctx = mx.gpu(0) if device else mx.cpu()
-        val = [mx.nd.zeros(cur_shape, ctx) for i in cur_keys]
+        val = [mx.nd.zeros(cur_shape, ctx=ctx, dtype=get_dtype(i, cur_keys)) 
for i in range(len(cur_keys))]
         for i in range(len(cur_keys)):
             expected = i
-            kv.init(cur_keys[i], [mx.nd.ones(cur_shape, ctx) * i])
+            kv.init(cur_keys[i], [mx.nd.ones(cur_shape, ctx=ctx, 
dtype=get_dtype(i, cur_keys)) * i])
             kv.pull(cur_keys[i], out=val[i])
-            check_diff_to_scalar(val[i], expected)
+            check_diff(val[i], expected)
     check_init(kv, init_test_keys, shape)
     check_init(kv, init_test_keys_big, big_shape)
-    check_init(kv, init_test_keys_device, shape, device=True)
-    check_init(kv, init_test_keys_device_big, big_shape, device=True)
-    my_rank = kv.rank
-    print('worker ' + str(my_rank) + ' is initialized')
+    if gpu_tests:
+        check_init(kv, init_test_keys_device, shape, device=True)
+        check_init(kv, init_test_keys_device_big, big_shape, device=True)
+    print('worker ' + str(kv.rank) + ' is initialized')
 
 if __name__ == "__main__":
-    test_sync_init()
-    test_sync_push_pull()
+    parser = argparse.ArgumentParser(description='test distributed kvstore in 
dist_sync mode')
+    parser.add_argument('--nrepeat', type=int, default=7)
+    parser.add_argument('--type', type=str, default='all')
+    parser.add_argument('--no-gpu', dest='gpu', action='store_false')
+    parser.add_argument('--no-multiprecision', dest='multiprecision', 
action='store_false')
+    opt = parser.parse_args()
+    if opt.type == 'all' or  opt.type == 'init':
+        test_sync_init(opt.gpu)
+    kv = init_kv()
+    if opt.type == 'all' or  opt.type == 'default':
+        kv = set_optimizer(use_multiprecision=opt.multiprecision)
+        test_sync_push_pull(opt.nrepeat)
+    # dont run non compressed tests after this as kvstore compression will be 
set here
+    if opt.type == 'all' or  opt.type == 'compressed':
+        kv, threshold = init_kv_compressed(kv)
+        test_sync_2bit_compression(threshold, opt.nrepeat)


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log in to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to