This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 33226ab  Remove kvstore calls from FM example (#10946)
33226ab is described below

commit 33226abe767b1311942b7e1805557838a2aa1d36
Author: Haibin Lin <linhaibin.e...@gmail.com>
AuthorDate: Thu May 17 09:45:26 2018 -0700

    Remove kvstore calls from FM example (#10946)
    
    * Remove kvstore calls from FM example
    
    * fix typo
---
 example/sparse/factorization_machine/train.py | 24 +++++++++++++-----------
 example/sparse/linear_classification/train.py |  2 +-
 2 files changed, 14 insertions(+), 12 deletions(-)

diff --git a/example/sparse/factorization_machine/train.py b/example/sparse/factorization_machine/train.py
index 741cf95..af3d60b 100644
--- a/example/sparse/factorization_machine/train.py
+++ b/example/sparse/factorization_machine/train.py
@@ -58,6 +58,7 @@ parser.add_argument('--log-interval', type=int, default=100,
 parser.add_argument('--kvstore', type=str, default='local',
                     help='what kvstore to use', choices=["dist_async", 
"local"])
 
+
 if __name__ == '__main__':
     import logging
     head = '%(asctime)-15s %(message)s'
@@ -75,6 +76,16 @@ if __name__ == '__main__':
     assert(args.data_train is not None and args.data_test is not None), \
           "dataset for training or test is missing"
 
+    def batch_row_ids(data_batch):
+        """ Generate row ids based on the current mini-batch """
+        idx = data_batch.data[0].indices
+        return {'w': idx, 'v': idx}
+
+    def all_row_ids(data_batch):
+        """ Generate row ids for all rows """
+        all_rows = mx.nd.arange(0, num_features, dtype='int64')
+        return {'w': all_rows, 'v': all_rows}
+
     # create kvstore
     kv = mx.kvstore.create(kvstore)
     # data iterator
@@ -102,12 +113,6 @@ if __name__ == '__main__':
     metric = mx.metric.create(['log_loss'])
     speedometer = mx.callback.Speedometer(batch_size, log_interval)
 
-    # get the sparse weight parameter
-    w_index = mod._exec_group.param_names.index('w')
-    w_param = mod._exec_group.param_arrays[w_index]
-    v_index = mod._exec_group.param_names.index('v')
-    v_param = mod._exec_group.param_arrays[v_index]
-
     logging.info('Training started ...')
     train_iter = iter(train_data)
     eval_iter = iter(eval_data)
@@ -118,9 +123,7 @@ if __name__ == '__main__':
             nbatch += 1
             # manually pull sparse weights from kvstore so that _square_sum
             # only computes the rows necessary
-            row_ids = batch.data[0].indices
-            kv.row_sparse_pull('w', w_param, row_ids=[row_ids], 
priority=-w_index)
-            kv.row_sparse_pull('v', v_param, row_ids=[row_ids], 
priority=-v_index)
+            mod.prepare(batch, sparse_row_id_fn=batch_row_ids)
             mod.forward_backward(batch)
             # update all parameters (including the weight parameter)
             mod.update()
@@ -131,8 +134,7 @@ if __name__ == '__main__':
             speedometer(speedometer_param)
 
         # pull all updated rows before validation
-        kv.row_sparse_pull('w', w_param, row_ids=[row_ids], priority=-w_index)
-        kv.row_sparse_pull('v', v_param, row_ids=[row_ids], priority=-v_index)
+        mod.prepare(None, all_row_ids)
         # evaluate metric on validation dataset
         score = mod.score(eval_iter, ['log_loss'])
         logging.info("epoch %d, eval log loss = %s" % (epoch, score[0][1]))
diff --git a/example/sparse/linear_classification/train.py b/example/sparse/linear_classification/train.py
index cde40dd..4d60efb 100644
--- a/example/sparse/linear_classification/train.py
+++ b/example/sparse/linear_classification/train.py
@@ -46,7 +46,7 @@ AVAZU = {
 
 def batch_row_ids(data_batch):
     """ Generate row ids based on the current mini-batch """
-    return {'weight': batch.data[0].indices}
+    return {'weight': data_batch.data[0].indices}
 
 def all_row_ids(data_batch):
     """ Generate row ids for all rows """

-- 
To stop receiving notification emails like this one, please contact
hai...@apache.org.

Reply via email to