jenkins-bot has submitted this change and it was merged. ( https://gerrit.wikimedia.org/r/391728 )

Change subject: Speed up DBN evaluation.
......................................................................


Speed up DBN evaluation.

The toDF() call in dbn.py causes spark to evaluate one partition on its
own just to figure out what the field types are; later spark evaluates
the other 199 partitions for the actual data. In a test with a dataframe
containing enwiki and dewiki a single partition can take up to 15
minutes. Avoid this extra pass by defining the schema explicitly instead
of making spark infer it.

15 minutes is also a long time for a single partition to run. Use a
heuristic to increase the number of partitions from 200 up to 2000 when
we have more data. In tests this patch cut the total dbn time from 23
minutes to 8.
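
The heuristic from data_pipeline.py is a plain clamp and can be
sketched in isolation; the 125k rows-per-partition target is the
patch's rough guess, not a tuned constant, and the function name here
is illustrative:

```python
def dbn_partition_count(nb_samples, rows_per_partition=125000,
                        floor=200, ceiling=2000):
    """Target roughly rows_per_partition rows per partition, but never
    go below spark's default of 200 or above 2000 partitions."""
    return int(max(floor, min(ceiling, nb_samples // rows_per_partition)))
```

Small inputs keep the default 200 partitions; very large inputs are
capped at 2000 so no single partition runs for tens of minutes while
the scheduler is not flooded with tiny tasks.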

Change-Id: I14d663f49a54b7bd130186aebfbeffde1e1a6d82
---
M mjolnir/dbn.py
M mjolnir/utilities/data_pipeline.py
2 files changed, 19 insertions(+), 6 deletions(-)

Approvals:
  jenkins-bot: Verified
  DCausse: Looks good to me, approved



diff --git a/mjolnir/dbn.py b/mjolnir/dbn.py
index 536064e..8c4ab76 100644
--- a/mjolnir/dbn.py
+++ b/mjolnir/dbn.py
@@ -9,6 +9,7 @@
 import json
 import pyspark.sql
 from pyspark.sql import functions as F
+from pyspark.sql import types as T
 import mjolnir.spark
 
 
@@ -179,7 +180,7 @@
         model.train(sessions)
         return _extract_labels_from_dbn(model, reader)
 
-    return (
+    rdd_rel = (
         df
         # group and collect up the hits for individual (wikiid, norm_query_id,
         # session_id) tuples to match how the dbn expects to receive data.
@@ -192,7 +193,14 @@
         # of grouping into python, but that could just as well end up worse?
         .repartition(num_partitions, 'wikiid', 'norm_query_id')
         # Run each partition through the DBN to generate relevance scores.
-        .rdd.mapPartitions(train_partition)
-        # Convert the rdd of tuples back into a DataFrame so the fields all
-        # have a name.
-        .toDF(['wikiid', 'norm_query_id', 'hit_page_id', 'relevance']))
+        .rdd.mapPartitions(train_partition))
+
+    # Using toDF() is very slow as it has to run some of the partitions to check their
+    # types, and then run all the partitions later to get the actual data. To prevent
+    # running twice specify the schema we expect.
+    return df.sql_ctx.createDataFrame(rdd_rel, T.StructType([
+        T.StructField('wikiid', T.StringType(), False),
+        T.StructField('norm_query_id', T.LongType(), False),
+        T.StructField('hit_page_id', T.LongType(), False),
+        T.StructField('relevance', T.DoubleType(), False)
+    ]))
diff --git a/mjolnir/utilities/data_pipeline.py b/mjolnir/utilities/data_pipeline.py
index a5c37d1..401be45 100644
--- a/mjolnir/utilities/data_pipeline.py
+++ b/mjolnir/utilities/data_pipeline.py
@@ -85,9 +85,14 @@
     print 'Fetched a total of %d samples for %d wikis' % (nb_samples, len(wikis))
     df_norm.unpersist()
 
+    # Target around 125k rows per partition. Note that this isn't
+    # how many the dbn will see, because it gets collected up. Just
+    # a rough guess.
+    dbn_partitions = int(max(200, min(2000, nb_samples / 125000)))
+
     # Learn relevances
     df_rel = (
-        mjolnir.dbn.train(df_sampled, {
+        mjolnir.dbn.train(df_sampled, num_partitions=dbn_partitions, dbn_config={
             'MAX_ITERATIONS': 40,
             'DEBUG': False,
             'PRETTY_LOG': True,

-- 
To view, visit https://gerrit.wikimedia.org/r/391728
To unsubscribe, visit https://gerrit.wikimedia.org/r/settings

Gerrit-MessageType: merged
Gerrit-Change-Id: I14d663f49a54b7bd130186aebfbeffde1e1a6d82
Gerrit-PatchSet: 2
Gerrit-Project: search/MjoLniR
Gerrit-Branch: master
Gerrit-Owner: EBernhardson <[email protected]>
Gerrit-Reviewer: DCausse <[email protected]>
Gerrit-Reviewer: jenkins-bot <>

