piiswrong closed pull request #8345: Misc fixes for sparse distributed training
URL: https://github.com/apache/incubator-mxnet/pull/8345
 
 
   

This is a PR merged from a forked repository.
Because GitHub hides the original diff of a foreign pull request (one
opened from a fork) once it has been merged, the diff is reproduced
below for the sake of provenance:

diff --git a/example/sparse/linear_classification.py b/example/sparse/linear_classification.py
index b173d04139..70f896386c 100644
--- a/example/sparse/linear_classification.py
+++ b/example/sparse/linear_classification.py
@@ -96,6 +96,7 @@
     # get the sparse weight parameter
     weight_index = mod._exec_group.param_names.index('weight')
     weight_param = mod._exec_group.param_arrays[weight_index]
+    all_row_ids = mx.nd.arange(0, num_features, dtype='int64')
     speedometer = mx.callback.Speedometer(batch_size, 100)
 
     logging.info('Training started ...')
@@ -118,9 +119,15 @@
             speedometer_param = mx.model.BatchEndParam(epoch=epoch, nbatch=nbatch,
                                                        eval_metric=metric, locals=locals())
             speedometer(speedometer_param)
+        # pull all rows before making a checkpoint
+        if kv:
+            kv.row_sparse_pull('weight', weight_param, row_ids=[all_row_ids],
+                               priority=-weight_index)
         # evaluate metric on validation dataset
         score = mod.score(eval_data, ['nll_loss'])
         logging.info('epoch %d, eval nll = %s ' % (epoch, score[0][1]))
+        save_optimizer_states = 'dist' not in kv.type
+        mod.save_checkpoint("checkpoint", epoch, save_optimizer_states=False)
         # reset the iterator for next pass of data
         data_iter.reset()
     logging.info('Training completed.')
diff --git a/src/kvstore/kvstore_dist.h b/src/kvstore/kvstore_dist.h
index 2d5e52fc3a..5e62be8c4c 100644
--- a/src/kvstore/kvstore_dist.h
+++ b/src/kvstore/kvstore_dist.h
@@ -42,10 +42,6 @@ namespace kvstore {
 /**
  * \brief distributed kvstore
  *
- * for a worker node, it always guarantees that all push and pull issued from
- * this worker on the same key are serialized. namely push(3) and then pull(3),
- * then the data pulled is always containing the modification from the push(3).
- *
  * it's the server node's job to control the data consistency among all
  * workers. see details on \ref ServerHandle::Start
  */
@@ -248,7 +244,7 @@ class KVStoreDist : public KVStoreLocal {
         LOG(FATAL) << "RowSparsePull with multiple values is not implemented yet";
       } else {
         auto& indices = target_val_rowids[0].second;
-        PullRowSparse_(key, &recv_buf, indices, priority);
+        PullRowSparse_(key, recv_buf, indices, priority);
         comm_->BroadcastRowSparse(key, recv_buf, grouped_val_rowid, num_vals == 1, priority);
       }
     }
@@ -322,24 +318,24 @@ class KVStoreDist : public KVStoreLocal {
   }
 
   // pull row sparse weight into `recv_buf` based on indices given by `indices`
-  void PullRowSparse_(const int key, NDArray *recv_buf, const NDArray& indices, int priority) {
+  void PullRowSparse_(const int key, const NDArray& recv_buf,
+                      const NDArray& indices, int priority) {
     using namespace rowsparse;
     auto pull_from_servers = [this, key, recv_buf, indices]
                              (RunContext rctx, Engine::CallbackOnComplete cb) {
       // allocate memory for the buffer
       size_t num_rows = indices.shape().Size();
-      recv_buf->CheckAndAlloc({mshadow::Shape1(num_rows)});
+      recv_buf.CheckAndAlloc({mshadow::Shape1(num_rows)});
 #if MKL_EXPERIMENTAL == 1
-      mkl_set_tblob_eager_mode(recv_buf->data());
+      mkl_set_tblob_eager_mode(recv_buf.data());
 #endif
-      real_t* data = recv_buf->data().dptr<real_t>();
-      auto indices_data = indices.data();
-      const auto offsets = indices_data.dptr<int64_t>();
-      const auto unit_len = recv_buf->shape().ProdShape(1, recv_buf->shape().ndim());
+      real_t* data = recv_buf.data().dptr<real_t>();
+      const auto offsets = indices.data().dptr<int64_t>();
+      const auto unit_len = recv_buf.shape().ProdShape(1, recv_buf.shape().ndim());
       const int64_t size = num_rows * unit_len;
        // convert to ps keys in row sparse format
       PSKV& pskv = EncodeRowSparseKey(key, size, num_rows, offsets,
-                                      unit_len, recv_buf->shape()[0]);
+                                      unit_len, recv_buf.shape()[0]);
       if (this->log_verbose_) {
         LOG(INFO) << "worker " << get_rank() << " pull lens: " << pskv.lens << " keys: "
                   << pskv.keys << " size: " << size;
@@ -348,8 +344,8 @@ class KVStoreDist : public KVStoreLocal {
       // copy indices to recv_buf. this needs to be done before ZPull
       // because after pull is done, the callback function returns and locks are released.
       // at this point, later functions may access the indices variable while copy happens
-      mshadow::Copy(recv_buf->aux_data(kIdx).FlatTo1D<cpu, int64_t>(),
-                    indices_data.FlatTo1D<cpu, int64_t>());
+      mshadow::Copy(recv_buf.aux_data(kIdx).FlatTo1D<cpu, int64_t>(),
+                    indices.data().FlatTo1D<cpu, int64_t>());
       CHECK_NOTNULL(ps_worker_)->ZPull(pskv.keys, vals, &pskv.lens, kRowSparsePushPull,
         [vals, cb]() { delete vals; cb(); });
     };
@@ -357,7 +353,7 @@ class KVStoreDist : public KVStoreLocal {
         pull_from_servers,
         pinned_ctx_,
         {indices.var()},
-        {recv_buf->var()},
+        {recv_buf.var()},
         FnProperty::kNormal,
         priority,
         PROFILER_MESSAGE("KVStoreDistRowSparsePull"));
@@ -366,15 +362,14 @@ class KVStoreDist : public KVStoreLocal {
   // push row sparse gradient
   void PushRowSparse(int key, const NDArray &send_buf, int priority) {
     using namespace rowsparse;
-    auto push_to_servers = [this, key, &send_buf]
+    auto push_to_servers = [this, key, send_buf]
                            (RunContext rctx, Engine::CallbackOnComplete cb) {
 #if MKL_EXPERIMENTAL == 1
       mkl_set_tblob_eager_mode(send_buf.data());
 #endif
       real_t* data = send_buf.data().dptr<real_t>();
-      bool init = send_buf.storage_initialized();
-      const int64_t num_rows = init ? send_buf.aux_shape(kIdx)[0] : 0;
-      const auto offsets = init ? send_buf.aux_data(kIdx).dptr<int64_t>() : nullptr;
+      const int64_t num_rows = send_buf.aux_shape(kIdx)[0];
+      const auto offsets = send_buf.aux_data(kIdx).dptr<int64_t>();
       const auto unit_len = send_buf.shape().ProdShape(1, send_buf.shape().ndim());
       const int64_t size = num_rows * unit_len;
 
@@ -472,7 +467,7 @@ class KVStoreDist : public KVStoreLocal {
     return pskv;
   }
 
-  // TODO(haibin) this encoding method for row sparse keys doesn't allow cross-layer batching
+  // Note: this encoding method for row sparse keys doesn't allow cross-layer batching
   inline PSKV& EncodeRowSparseKey(const int key, const int64_t size, const int64_t num_rows,
                                   const int64_t *offsets, const size_t unit_len,
                                   const int64_t total_num_rows) {
@@ -495,15 +490,15 @@ class KVStoreDist : public KVStoreLocal {
         ps::Key master_key = krs[i].begin() + key;
         pskv.keys.push_back(master_key);
         pskv.lens.push_back(0);
-        if (offsets) {
+        if (offsets && size > 0) {
           // calculate partition ranges
           int64_t part_num_rows =
             llround(static_cast<double>(total_num_rows) / num_servers * (i + 1)) -
             llround(static_cast<double>(total_num_rows) / num_servers * i);
           auto end_row = start_row + part_num_rows;
+          // search for offsets in [start_row, end_row)
           auto lb = std::lower_bound(offsets, offsets + num_rows, start_row);
           auto ub = std::upper_bound(offsets, offsets + num_rows, end_row - 1);
-
           for (auto offset = lb; offset < ub; offset++) {
             ps::Key ps_key = krs[i].begin() + key + (*offset - start_row);
             CHECK_LT(ps_key, krs[i].end());
diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h
index 97bda906c6..ea4243ee91 100644
--- a/src/operator/tensor/init_op.h
+++ b/src/operator/tensor/init_op.h
@@ -93,6 +93,7 @@ struct RangeParam : public dmlc::Parameter<RangeParam> {
     .add_enum("float16", mshadow::kFloat16)
     .add_enum("uint8", mshadow::kUint8)
     .add_enum("int32", mshadow::kInt32)
+    .add_enum("int64", mshadow::kInt64)
     .describe("Target data type.");
   }
 };
@@ -179,6 +180,13 @@ void FillCompute(const nnvm::NodeAttrs& attrs,
   });
 }
 
+struct PopulateFullIdxRspKernel {
+  template<typename IType>
+  MSHADOW_XINLINE static void Map(int i, IType* out) {
+    KERNEL_ASSIGN(out[i], kWriteTo, i);
+  }
+};
+
 // Fill in the indices and values of a RowSparse NDArray to represent a zeros NDArray,
 // instead of the usual compact representation.
 template<typename xpu>
@@ -192,21 +200,14 @@ inline void FillDnsZerosRspImpl(mshadow::Stream<xpu> *s, NDArray *dst) {
     MSHADOW_IDX_TYPE_SWITCH(dst->aux_type(kIdx), IType, {
       auto num_rows = dst->shape()[0];
       dst->CheckAndAlloc({Shape1(num_rows)});
-      auto idx = dst->aux_data(kIdx).FlatTo1D<xpu, IType>(s);
+      auto idx = dst->aux_data(kIdx);
       auto val = dst->data();
       Kernel<set_zero, xpu>::Launch(s, val.Size(), val.dptr<DType>());
-      ASSIGN_DISPATCH(idx, kWriteTo, range<IType>(0, num_rows, 1, 1));
+      Kernel<PopulateFullIdxRspKernel, xpu>::Launch(s, num_rows, idx.dptr<IType>());
     });
   });
 }
 
-struct PopulateFullIdxRspKernel {
-  template<typename IType>
-  MSHADOW_XINLINE static void Map(int i, IType* out) {
-    KERNEL_ASSIGN(out[i], kWriteTo, i);
-  }
-};
-
 // Fill full indices NDArray with zeros by updating the aux shape.
 template<typename xpu>
 void PopulateFullIdxRspImpl(mshadow::Stream<xpu> *s, NDArray *dst) {
diff --git a/tests/nightly/dist_sync_kvstore.py b/tests/nightly/dist_sync_kvstore.py
index 5f1b11f041..900d6bb6af 100644
--- a/tests/nightly/dist_sync_kvstore.py
+++ b/tests/nightly/dist_sync_kvstore.py
@@ -39,7 +39,7 @@ def check_diff_to_scalar(A, x, rank=None):
 
 rate = 2
 shape = (2, 3)
-big_shape = (1200, 1200)        # bigger than BIGARRAY_BOUND
+big_shape = (1200, 1200)        # bigger than MXNET_KVSTORE_BIGARRAY_BOUND
 
 kv = mx.kv.create('dist_sync')
 
@@ -104,24 +104,27 @@ def check_row_sparse_keys(kv, my_rank, nworker):
     def check_row_sparse_keys_with_zeros(kv, my_rank, nworker):
         nrepeat = 3
         # prepare gradient
-        v = mx.nd.zeros(shape)
-        big_v = mx.nd.zeros(big_shape)
+        v = mx.nd.sparse.zeros('row_sparse', shape)
+        big_v = mx.nd.sparse.zeros('row_sparse', big_shape)
         # push
         for i in range(nrepeat):
-            kv.push('11', v.tostype('row_sparse'))
-            kv.push('100', big_v.tostype('row_sparse'))
-
+            kv.push('11', v)
+            kv.push('100', big_v)
             # pull a subset of rows this worker is interested in
             all_row_ids = np.arange(shape[0])
-            val = mx.nd.ones(shape).tostype('row_sparse')
-            big_val = mx.nd.ones(big_shape).tostype('row_sparse')
-            kv.row_sparse_pull('11', out=val, row_ids=mx.nd.array(all_row_ids, dtype='int64'))
-            big_num_rows = shape[0]
+            val = mx.nd.sparse.zeros('row_sparse', shape)
+            big_val = mx.nd.sparse.zeros('row_sparse', big_shape)
+            kv.row_sparse_pull('11', out=val, row_ids=mx.nd.array(all_row_ids))
             big_all_row_ids = np.arange(big_shape[0])
-            kv.row_sparse_pull('100', out=big_val, row_ids=mx.nd.array(big_all_row_ids, dtype='int64'))
+            kv.row_sparse_pull('100', out=big_val, row_ids=mx.nd.array(big_all_row_ids))
             # verify results
-            check_diff_to_scalar(val, mx.nd.ones(shape))
-            check_diff_to_scalar(big_val, mx.nd.ones(big_shape))
+            check_diff_to_scalar(val, 1)
+            check_diff_to_scalar(big_val, 1)
+            # pull empty weights
+            kv.row_sparse_pull('11', out=val, row_ids=mx.nd.array([]))
+            kv.row_sparse_pull('100', out=big_val, row_ids=mx.nd.array([]))
+            check_diff_to_scalar(val, 0)
+            check_diff_to_scalar(big_val, 0)
 
     def check_big_row_sparse_keys(kv, my_rank, nworker):
         mx.random.seed(123)
@@ -154,7 +157,7 @@ def check_big_row_sparse_keys(kv, my_rank, nworker):
             rnd.seed(my_rank)
             num_rows = big_shape[0]
             row_ids_np = np.random.randint(num_rows, size=num_rows)
-            row_ids = mx.nd.array(row_ids_np, dtype='int64')
+            row_ids = mx.nd.array(row_ids_np)
             # perform pull
             val = mx.nd.zeros(big_shape, stype='row_sparse')
             kv.row_sparse_pull('100', out=val, row_ids=row_ids)
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index 576d963540..fc8c350bbb 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -734,6 +734,8 @@ def test_output():
     assert_almost_equal(out.asnumpy(), zeros.asnumpy())
     mx.nd.full(shape, 2, out=out)
     assert_almost_equal(out.asnumpy(), ones.asnumpy() * 2)
+    arange_out = mx.nd.arange(0, 20, dtype='int64')
+    assert_almost_equal(arange_out.asnumpy(), np.arange(0, 20))
 
 def test_ndarray_fluent():
     has_grad = set(['flatten', 'expand_dims', 'flip', 'tile', 'transpose', 'sum', 'nansum', 'prod',
diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py
index 8666b9e714..1a26434015 100644
--- a/tests/python/unittest/test_optimizer.py
+++ b/tests/python/unittest/test_optimizer.py
@@ -232,6 +232,10 @@ def test_sgd():
                                 if dtype != np.float16:
                                     compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape[:2],
                                                       dtype, w_stype='csr', g_stype='csr')
+    # test optimizer with a big shape
+    big_shape = (54686454, 1)
+    kwarg = {'momentum': 0.9, 'wd': 0.05}
+    compare_optimizer(opt1(**kwarg), opt2(**kwarg), big_shape, np.float32)
 
 class PySparseSGD(mx.optimizer.Optimizer):
     """python reference implemenation of sgd"""


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to