This is an automated email from the ASF dual-hosted git repository. haibin pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push: new 33226ab Remove kvstore calls from FM example (#10946) 33226ab is described below commit 33226abe767b1311942b7e1805557838a2aa1d36 Author: Haibin Lin <linhaibin.e...@gmail.com> AuthorDate: Thu May 17 09:45:26 2018 -0700 Remove kvstore calls from FM example (#10946) * Remove kvstore calls from FM example * fix typo --- example/sparse/factorization_machine/train.py | 24 +++++++++++++----------- example/sparse/linear_classification/train.py | 2 +- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/example/sparse/factorization_machine/train.py b/example/sparse/factorization_machine/train.py index 741cf95..af3d60b 100644 --- a/example/sparse/factorization_machine/train.py +++ b/example/sparse/factorization_machine/train.py @@ -58,6 +58,7 @@ parser.add_argument('--log-interval', type=int, default=100, parser.add_argument('--kvstore', type=str, default='local', help='what kvstore to use', choices=["dist_async", "local"]) + if __name__ == '__main__': import logging head = '%(asctime)-15s %(message)s' @@ -75,6 +76,16 @@ if __name__ == '__main__': assert(args.data_train is not None and args.data_test is not None), \ "dataset for training or test is missing" + def batch_row_ids(data_batch): + """ Generate row ids based on the current mini-batch """ + idx = data_batch.data[0].indices + return {'w': idx, 'v': idx} + + def all_row_ids(data_batch): + """ Generate row ids for all rows """ + all_rows = mx.nd.arange(0, num_features, dtype='int64') + return {'w': all_rows, 'v': all_rows} + # create kvstore kv = mx.kvstore.create(kvstore) # data iterator @@ -102,12 +113,6 @@ if __name__ == '__main__': metric = mx.metric.create(['log_loss']) speedometer = mx.callback.Speedometer(batch_size, log_interval) - # get the sparse weight parameter - w_index = mod._exec_group.param_names.index('w') - w_param = mod._exec_group.param_arrays[w_index] - v_index = mod._exec_group.param_names.index('v') - v_param = mod._exec_group.param_arrays[v_index] - logging.info('Training started ...') train_iter = iter(train_data) eval_iter = iter(eval_data) @@ -118,9 +123,7 @@ if __name__ == '__main__': nbatch += 1 # manually pull sparse weights from kvstore so that _square_sum # only computes the rows necessary - row_ids = batch.data[0].indices - kv.row_sparse_pull('w', w_param, row_ids=[row_ids], priority=-w_index) - kv.row_sparse_pull('v', v_param, row_ids=[row_ids], priority=-v_index) + mod.prepare(batch, sparse_row_id_fn=batch_row_ids) mod.forward_backward(batch) # update all parameters (including the weight parameter) mod.update() @@ -131,8 +134,7 @@ if __name__ == '__main__': speedometer(speedometer_param) # pull all updated rows before validation - kv.row_sparse_pull('w', w_param, row_ids=[row_ids], priority=-w_index) - kv.row_sparse_pull('v', v_param, row_ids=[row_ids], priority=-v_index) + mod.prepare(None, all_row_ids) # evaluate metric on validation dataset score = mod.score(eval_iter, ['log_loss']) logging.info("epoch %d, eval log loss = %s" % (epoch, score[0][1])) diff --git a/example/sparse/linear_classification/train.py b/example/sparse/linear_classification/train.py index cde40dd..4d60efb 100644 --- a/example/sparse/linear_classification/train.py +++ b/example/sparse/linear_classification/train.py @@ -46,7 +46,7 @@ AVAZU = { def batch_row_ids(data_batch): """ Generate row ids based on the current mini-batch """ - return {'weight': batch.data[0].indices} + return {'weight': data_batch.data[0].indices} def all_row_ids(data_batch): """ Generate row ids for all rows """ -- To stop receiving notification emails like this one, please contact hai...@apache.org.