EBernhardson has uploaded a new change for review. ( 
https://gerrit.wikimedia.org/r/403336 )

Change subject: [WIP] distributed training for lightgbm
......................................................................

[WIP] distributed training for lightgbm

Untested. The daemon never shuts down cleanly.

Change-Id: Id50f4f53b221003a89555e870bb771ba26faad21
---
M mjolnir/training/lightgbm.py
1 file changed, 267 insertions(+), 81 deletions(-)


  git pull ssh://gerrit.wikimedia.org:29418/search/MjoLniR 
refs/changes/36/403336/1

diff --git a/mjolnir/training/lightgbm.py b/mjolnir/training/lightgbm.py
index cbc8883..460740b 100644
--- a/mjolnir/training/lightgbm.py
+++ b/mjolnir/training/lightgbm.py
@@ -9,7 +9,11 @@
 from mjolnir.utils import as_local_paths
 from multiprocessing.dummy import Pool
 import numpy as np
+import Pyro4
 import pyspark
+import socket
+import threading
+import time
 
 
 def _overrideParamsAccordingToTaskCpus(sc, params):
@@ -33,7 +37,9 @@
                 ds._free_handle()
 
 
-def build_distributed_boosters(rdd, params, train_matrix):
+def build_distributed_boosters(rdd, params, train_matrix, client):
+    num_partitions = rdd.getNumPartitions()
+
     def build_partition(rows):
         fold = rows.next()
         try:
@@ -47,7 +53,11 @@
             num_rounds = params['num_rounds']
             del params['num_rounds']
 
-        # TODO: Generalize
+        if client is not None:
+            machines, listen_port = client.request_machine_list(num_partitions)
+            params['machines'] = machines
+            params['local_listen_port'] = listen_port
+
         with load_datasets(fold) as datasets:
             eval_results = {}
             gbm = lgb.train(
@@ -71,7 +81,7 @@
             params[k] = val_type(params[k])
 
 
-def train(fold, paramOverrides, train_matrix=None):
+def train(fold, paramOverrides, train_matrix=None, client=None):
     sc = pyspark.SparkContext.getOrCreate()
     params = {
         'boosting_type': 'gbdt',
@@ -95,13 +105,15 @@
     if (len(fold) > 1):
         rdd = sc.parallelize(list(enumerate(fold)), 1).partitionBy(len(fold), 
lambda x: x).map(lambda x: x[1])
         raise Exception("TODO: Distributed Training")
+        if client is None:
+            raise Exception("client required for distributed training")
     else:
         rdd = sc.parallelize(fold, 1)
 
     if train_matrix is None:
         train_matrix = "all" if "all" in fold else "train"
 
-    booster, metrics = build_distributed_boosters(rdd, params, 
train_matrix).collect()[0]
+    booster, metrics = build_distributed_boosters(rdd, params, train_matrix, 
client).collect()[0]
     return LightGBMModel(booster, metrics)
 
 
@@ -132,90 +144,264 @@
         self._booster.save_model(path)
 
 
+DAEMON_PORT = 6827
+
+
 def tune(folds, stats, train_matrix, num_cv_jobs=5, num_workers=5, 
initial_num_trees=100, final_num_trees=500):
     cv_pool = None
     if num_cv_jobs > 1:
         cv_pool = Pool(num_cv_jobs)
 
-    # Configure the trials pool large enough to keep cv_pool full
-    num_folds = len(folds)
-    num_workers = len(folds[0])
-    trials_pool_size = int(math.floor(num_cv_jobs / (num_workers * num_folds)))
-    if trials_pool_size > 1:
-        trials_pool = Pool(trials_pool_size)
-    else:
-        trials_pool = None
+    with Daemon(socket.gethostname(), DAEMON_PORT) as daemon:
+        while not daemon.ready:
+            time.sleep(1)
 
-    train_func = functools.partial(train, train_matrix=train_matrix)
+        # Configure the trials pool large enough to keep cv_pool full
+        num_folds = len(folds)
+        num_workers = len(folds[0])
+        trials_pool_size = int(math.floor(num_cv_jobs / (num_workers * 
num_folds)))
+        if trials_pool_size > 1:
+            trials_pool = Pool(trials_pool_size)
+        else:
+            trials_pool = None
 
-    def eval_space(space, max_evals):
-        max_evals = 2  # TODO: remove
-        best, trials = mjolnir.training.hyperopt.minimize(
-            folds, train_func, space, max_evals=max_evals,
-            cv_pool=cv_pool, trials_pool=trials_pool)
-        for k, v in space.items():
-            if not np.isscalar(v):
-                print 'best %s: %f' % (k, best[k])
-        return best, trials
+        kwargs = {'train_matrix': train_matrix}
+        if num_workers > 1:
+            kwargs['client'] = Client(daemon.url)
+        train_func = functools.partial(train, **kwargs)
 
-    space = {
-        'boosting_type': 'gbdt',
-        'objective': 'lambdarank',
-        'metric': 'ndcg',
-        'ndcg_eval_at': '1,3,10',
-        'is_training_metric': True,
-        'num_rounds': initial_num_trees,
-        'max_bin': 255,
-        'num_leaves': 63,
-        'learning_rate': 0.1,
-        'feature_fraction': 1.0,
-        'bagging_fraction': 0.9,
-        'bagging_freq': 1,
-    }
-    tune_spaces = [
-        ('initial', {
-            'iterations': 5,
-            'space': {
-                'learning_rate': hyperopt.hp.uniform('learning_rate', 0.1, 
0.4),
-                'num_leaves': hyperopt.hp.quniform('num_leaves', 60, 150, 10),
-                'min_data_in_leaf': hyperopt.hp.quniform('min_data_in_leaf', 
25, 200, 25),
-                'min_sum_hessian_in_leaf': 
hyperopt.hp.uniform('min_sum_hessian_in_leaf', 1.0, 10.0),
-                'feature_fraction': hyperopt.hp.uniform('feature_fraction', 
0.8, 1.0),
-                'bagging_fraction': hyperopt.hp.uniform('bagging_fraction', 
0.8, 1.0),
-            }
-        }),
-        ('trees', {
-            'iterations': 30,
-            'condition': lambda: final_num_trees is not None and 
final_num_trees != initial_num_trees,
-            'space': {
-                'num_rounds': final_num_trees,
-                'learning_rate': hyperopt.hp.uniform('learning_rate', 0.01, 
0.4),
-            }
-        }),
-    ]
+        def eval_space(space, max_evals):
+            max_evals = 2  # TODO: remove
+            best, trials = mjolnir.training.hyperopt.minimize(
+                folds, train_func, space, max_evals=max_evals,
+                cv_pool=cv_pool, trials_pool=trials_pool)
+            for k, v in space.items():
+                if not np.isscalar(v):
+                    print 'best %s: %f' % (k, best[k])
+            return best, trials
 
-    stages = []
-    for name, stage_params in tune_spaces:
-        if 'condition' in stage_params and not stage_params['condition']():
-            continue
-        tune_space = stage_params['space']
-        for name, param in tune_space.items():
-            space[name] = param
-        best, trials = eval_space(space, stage_params['iterations'])
-        for name, param in tune_space.items():
-            space[name] = best[name]
-        stages.append((name, trials))
-
-    trials_final = stages[-1][1]
-    best_trial = np.argmin(trials_final.losses())
-    loss = trials_final.losses()[best_trial]
-    true_loss = trials_final.results[best_trial].get('true_loss')
-
-    return {
-        'trials': dict(stages),
-        'params': space,
-        'metrics': {
-            'cv-test': -loss,
-            'cv-train': -loss + true_loss
+        space = {
+            'boosting_type': 'gbdt',
+            'objective': 'lambdarank',
+            'metric': 'ndcg',
+            'ndcg_eval_at': '1,3,10',
+            'is_training_metric': True,
+            'num_rounds': initial_num_trees,
+            'max_bin': 255,
+            'num_leaves': 63,
+            'learning_rate': 0.1,
+            'feature_fraction': 1.0,
+            'bagging_fraction': 0.9,
+            'bagging_freq': 1,
         }
-    }
+        tune_spaces = [
+            ('initial', {
+                'iterations': 5,
+                'space': {
+                    'learning_rate': hyperopt.hp.uniform('learning_rate', 0.1, 
0.4),
+                    'num_leaves': hyperopt.hp.quniform('num_leaves', 60, 150, 
10),
+                    'min_data_in_leaf': 
hyperopt.hp.quniform('min_data_in_leaf', 25, 200, 25),
+                    'min_sum_hessian_in_leaf': 
hyperopt.hp.uniform('min_sum_hessian_in_leaf', 1.0, 10.0),
+                    'feature_fraction': 
hyperopt.hp.uniform('feature_fraction', 0.8, 1.0),
+                    'bagging_fraction': 
hyperopt.hp.uniform('bagging_fraction', 0.8, 1.0),
+                }
+            }),
+            ('trees', {
+                'iterations': 30,
+                'condition': lambda: final_num_trees is not None and 
final_num_trees != initial_num_trees,
+                'space': {
+                    'num_rounds': final_num_trees,
+                    'learning_rate': hyperopt.hp.uniform('learning_rate', 
0.01, 0.4),
+                }
+            }),
+        ]
+
+        stages = []
+        for name, stage_params in tune_spaces:
+            if 'condition' in stage_params and not stage_params['condition']():
+                continue
+            tune_space = stage_params['space']
+            for name, param in tune_space.items():
+                space[name] = param
+            best, trials = eval_space(space, stage_params['iterations'])
+            for name, param in tune_space.items():
+                space[name] = best[name]
+            stages.append((name, trials))
+
+        trials_final = stages[-1][1]
+        best_trial = np.argmin(trials_final.losses())
+        loss = trials_final.losses()[best_trial]
+        true_loss = trials_final.results[best_trial].get('true_loss')
+
+        return {
+            'trials': dict(stages),
+            'params': space,
+            'metrics': {
+                'cv-test': -loss,
+                'cv-train': -loss + true_loss
+            }
+        }
+
+
def convert_node_lgb2xgb(node, nodeid_gen, depth=0):
    """Recursively convert one node of a lightgbm tree dump into the
    xgboost json dump format.

    Parameters
    ----------
    node : dict
        A lightgbm 'tree_structure' node; split nodes carry
        'threshold', leaves carry 'leaf_value'.
    nodeid_gen : iterator of int
        Shared id generator; each node consumes one id via next().
    depth : int
        Depth of this node in the tree, recorded in the output.

    Returns
    -------
    dict in xgboost dump format ('children' recursion for splits,
    'leaf' for leaves).
    """
    if 'threshold' in node:
        # Split node. Only the '<=' decision type is supported.
        assert node['decision_type'] == '<='
        # Parent claims its id before recursing so ids are assigned
        # top-down.
        nodeid = next(nodeid_gen)
        # BUGFIX: the original had a trailing comma (making `children`
        # a 1-tuple wrapping the list) and used attribute access
        # (children[0].nodeid) on plain dicts -- both crashed.
        children = [convert_node_lgb2xgb(child, nodeid_gen, depth + 1)
                    for child in (node['left_child'], node['right_child'])]
        return {
            # NOTE(review): original maps yes=right, no=left; lightgbm
            # '<=' sends matching rows left, so confirm this matches
            # xgboost's yes/no orientation before relying on it.
            'no': children[0]['nodeid'],
            'missing': children[0]['nodeid'] if node['default_left'] else children[1]['nodeid'],
            'nodeid': nodeid,
            'children': children,
            'depth': depth,
            'split': node['split_feature'],
            'yes': children[1]['nodeid'],
            'split_condition': node['threshold']
        }
    else:
        # Leaf node.
        return {
            'leaf': node['leaf_value'],
            'nodeid': next(nodeid_gen)
        }
+
+
def convert_lgb2xgb(lgb_dump):
    """Convert a lightgbm json model dump into xgboost dump format.

    BUGFIX: the parameter was named ``xgb`` while the body referenced
    ``lgb`` -- which resolved to the lightgbm module import, not the
    argument. Renamed to make the intent unambiguous (the function
    crashed before, so no callers can depend on the old name).

    Parameters
    ----------
    lgb_dump : dict
        Output of lightgbm's Booster.dump_model(), with a 'tree_info'
        list of trees.

    Yields
    ------
    dict per tree, in xgboost dump format. Node ids are unique across
    all trees because a single id generator is shared.
    """
    # itertools.count is an unbounded counter and works on both
    # python 2 and 3 (the original used py2-only xrange).
    import itertools
    nodeid_gen = itertools.count()
    for tree in lgb_dump['tree_info']:
        # TODO: lightgbm reports shrinkage, do we need it?
        # shrinkage = tree['shrinkage']
        yield convert_node_lgb2xgb(tree['tree_structure'], nodeid_gen)
+
+
class PortAssigner(object):
    """Pyro4-exposed service coordinating distributed lightgbm setup.

    Each spark partition of a stage reports in with its hostname.
    Once all partitions of the stage have reported, every partition
    is assigned a distinct port on its host and the full
    (hostname, port) list is returned to each caller.
    """

    # First port handed out on each host. This seems pretty fragile,
    # i hope nothing else is assigning from our port range...
    STARTING_PORT = 13469

    def __init__(self):
        # stage_id -> list of hostnames indexed by partition id
        self.stages = {}
        # hostname -> next free port on that host
        self.next_port = {}
        # stage_id -> completed [(hostname, port), ...] assignment
        self.assignments = {}

    @Pyro4.expose
    def request_machine_list(self, stage_id, num_partitions, partition_id, hostname):
        """Record one partition's host and return the stage assignment.

        Returns ("wait", None) until every partition has reported,
        then ("ok", [(hostname, port), ...]) for all callers.
        """
        stage = self.stages.setdefault(stage_id, [None] * num_partitions)
        if len(stage) != num_partitions:
            raise Exception("Mismatched partition counts %d != %d" % (len(stage), num_partitions))
        seen = stage[partition_id]
        if seen is None:
            # First report from this partition.
            stage[partition_id] = hostname
        elif seen != hostname:
            # A different machine took over this partition. There is
            # no way to interrupt and tell the other workers, so give up.
            raise Exception("Overwrite not implemented")
        # else: duplicate call from the same host; nothing to record.

        if not all(stage):
            return ("wait", None)
        if stage_id not in self.assignments:
            self._assign_ports(stage_id)
        return ("ok", self.assignments[stage_id])

    def _assign_ports(self, stage_id):
        """Assign one distinct port per partition on its host.

        This seems pretty fragile, i hope nothing else is assigning
        from our port range...
        """
        assignment = []
        for host in self.stages[stage_id]:
            port = self.next_port.get(host, self.STARTING_PORT)
            self.next_port[host] = port + 1
            assignment.append((host, port))
        self.assignments[stage_id] = assignment
+
+
+Pyro4.config.SERVERTYPE = 'multiplex'
+
+
class Daemon(threading.Thread):
    """Runs a Pyro4 daemon serving a PortAssigner on a background thread.

    Intended use::

        with Daemon(host, port) as daemon:
            while not daemon.ready:
                time.sleep(1)
            ...use daemon.uri...

    The context manager starts the thread on enter and shuts the
    daemon down (and joins the thread) on exit.
    """

    def __init__(self, host, port):
        threading.Thread.__init__(self)
        self.host = host
        # May be incremented by run() if the port is already in use.
        self.port = port
        # Pyro4 uri of the registered PortAssigner, set once bound.
        self.uri = None
        # Flag callers poll before reading self.uri.
        self.ready = False
        # Set once the Pyro4 daemon is bound; guarded in stop().
        self._daemon = None

    def run(self):
        import errno
        service = PortAssigner()
        attempts = 0
        # Retry on successive ports in case ours is already taken.
        while attempts < 20:
            attempts += 1
            try:
                with Pyro4.Daemon(host=self.host, port=self.port) as daemon:
                    self._daemon = daemon
                    self.uri = daemon.register(service)
                    self.ready = True
                    daemon.requestLoop()
                    break
            except socket.error as e:
                # BUGFIX: the original fell through to `raise e` even
                # for errno 98 (address in use), so the retry loop was
                # dead code and the port never advanced.
                if e.errno != errno.EADDRINUSE:
                    raise
                self.port += 1

    def stop(self):
        # NOTE(review): relies on Pyro4.config.SERVERTYPE='multiplex'
        # for shutdown() from another thread to unblock requestLoop;
        # the commit message says the daemon never closes right, so
        # confirm this actually interrupts the loop.
        if self._daemon is not None:
            self._daemon.shutdown()

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, type, value, tb):
        self.stop()
        self.join()
+
+
class Client(object):
    """Picklable handle to the PortAssigner Pyro4 service.

    Only the uri survives pickling; the Pyro4 proxy is re-created
    lazily on the executor side after deserialization.
    """

    def __init__(self, uri):
        self.uri = uri
        self._connection = None

    def __getstate__(self):
        # Ship only the uri to executors; the proxy is not picklable.
        return self.uri

    def __setstate__(self, state):
        self.uri = state
        self._connection = None

    def connection(self):
        """Lazily create and cache the Pyro4 proxy for self.uri."""
        if self._connection is None:
            self._connection = Pyro4.Proxy(self.uri)
        return self._connection

    def request_machine_list(self, num_partitions, max_attempts=120):
        """Report this partition's host and block for the machine list.

        Polls the PortAssigner once per second until every partition
        of the current spark stage has checked in, then returns
        (machines, local_port) where machines is a list of
        "host:port" strings and local_port is this partition's port.
        Raises if not running inside a spark executor, on an
        unexpected response, or after max_attempts polls.
        """
        context = pyspark.TaskContext.get()
        if context is None:
            raise Exception('Expecting to run in executor context')

        stage_id = context.stageId()
        partition_id = context.partitionId()
        hostname = socket.gethostname()

        conn = self.connection()
        for _ in range(max_attempts):
            status, data = conn.request_machine_list(
                stage_id, num_partitions, partition_id, hostname)
            if status == "ok":
                # data is [(hostname, port), ...] indexed by partition.
                local_port = data[partition_id][1]
                machines = ["%s:%d" % pair for pair in data]
                return machines, local_port
            if status == "wait":
                # Not everyone has checked in yet; poll again shortly.
                time.sleep(1)
            else:
                raise Exception("Unexpected response: " + str(status))
        raise Exception("Timed out waiting on machine list")

-- 
To view, visit https://gerrit.wikimedia.org/r/403336
To unsubscribe, visit https://gerrit.wikimedia.org/r/settings

Gerrit-MessageType: newchange
Gerrit-Change-Id: Id50f4f53b221003a89555e870bb771ba26faad21
Gerrit-PatchSet: 1
Gerrit-Project: search/MjoLniR
Gerrit-Branch: master
Gerrit-Owner: EBernhardson <[email protected]>

_______________________________________________
MediaWiki-commits mailing list
[email protected]
https://lists.wikimedia.org/mailman/listinfo/mediawiki-commits

Reply via email to