ChaokunChang commented on a change in pull request #17555: [MXNET-#16795] Byteps-KVStore: Intergrate Byteps into mxnet as new type of kvstore backend URL: https://github.com/apache/incubator-mxnet/pull/17555#discussion_r376878037
########## File path: python/mxnet/kvstore/byteps.py ########## @@ -0,0 +1,210 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +""" BytePS backend for MXNet KVStore""" +from __future__ import absolute_import + +from ..ndarray import NDArray +from .base import KVStoreBase + +__all__ = ['BytePS'] + + +@KVStoreBase.register +class BytePS(KVStoreBase): + """BytePS backend for MXNet KVStore interface.""" + + def __init__(self): + """Initializes a new KVStore.""" + try: + import byteps.mxnet as bps + self.handle = bps + except ImportError as err: + print('Did not find BytePS library. Please install BytePS first') + raise err + self.handle.init() + + def broadcast(self, key, value, out, priority=0): + """ Broadcast the value NDArray at rank 0 to all ranks' out. If out is None, + the result is stored in `value`. + Parameters + ---------- + key : str, or int + The keys. + value : NDArray, or list of NDArray + Values corresponding to the key. + out : NDArray, or lise of NDArray + Values corresponding to the keys. 
+ Examples + -------- + >>> # broadcast a single key-value pair + >>> shape = (2,3) + >>> kv = mx.kv.create('byteps') + >>> a = mx.nd.zeros(shape) + >>> kv.broadcast('3', mx.nd.ones(shape)*2, out=a) + >>> print a.asnumpy() + [[ 2. 2. 2.] + [ 2. 2. 2.]] + """ + + # do not accept list or tuple for key/value + assert isinstance(key, (str, int)) + + # unpack the list if it contains just one NDArray + value = value[0] if isinstance( + value, list) and len(value) == 1 else value + assert isinstance( + value, NDArray), "The type of value can only be NDArray or list of NDArray which has only one element." + + # for non-root-rank, assign value with 0, thus the result of pushpull will be + # equal to the value of root-rank, thus implementing broadcast. + root_rank = 0 + if self.rank != root_rank: + value.__imul__(0) + self.handle.byteps_push_pull(value, version=0, priority=priority, + name=str(key), is_average=False) + # Make sure tensors pushed to MXNet engine get processed such that all + # workers are synced before starting training. + value.wait_to_read() Review comment: It's needed. The test_pushpull will abort if `wait_to_read()` is removed. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: users@infra.apache.org With regards, Apache Git Services
