EBernhardson has uploaded a new change for review. ( https://gerrit.wikimedia.org/r/391728 )
Change subject: Speed up DBN evaluation.
......................................................................
Speed up DBN evaluation.
The toDF() call in dbn.py causes us to evaluate one partition on its own
so spark can figure out what the field types are. Later spark will
evaluate the other 199 partitions. In a test with a dataframe
containing enwiki and dewiki a single partition can take up to 15
minutes. Avoid this by defining the schema explicitly instead of making
spark figure it out.
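For illustration only, a minimal sketch of the difference between letting
toDF() infer the schema and passing createDataFrame an explicit one; the
spark session setup and rows below are invented and not part of this change:

    from pyspark.sql import SparkSession
    from pyspark.sql import types as T

    spark = SparkSession.builder.getOrCreate()
    rdd = spark.sparkContext.parallelize([
        ('enwiki', 1, 42, 0.85),
        ('dewiki', 2, 7, 0.10),
    ])

    # Schema inference: spark runs part of the rdd up front just to guess
    # the field types, then runs everything again when the result is used.
    df_inferred = rdd.toDF(['wikiid', 'norm_query_id', 'hit_page_id', 'relevance'])

    # Explicit schema: no inference pass, the rdd is only evaluated once.
    schema = T.StructType([
        T.StructField('wikiid', T.StringType(), False),
        T.StructField('norm_query_id', T.LongType(), False),
        T.StructField('hit_page_id', T.LongType(), False),
        T.StructField('relevance', T.DoubleType(), False),
    ])
    df_explicit = spark.createDataFrame(rdd, schema)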
15 minutes is also a long time for a single partition to run. Use a
heuristic to increase the number of partitions from 200 up to 2000 when
we have more data. In tests this patch cut the total dbn time from 23
minutes to 8.
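As a rough sketch of that heuristic (the dbn_partition_count helper and the
sample counts are hypothetical; only the 125k-rows-per-partition target and
the 200..2000 clamp come from this change):

    def dbn_partition_count(nb_samples, rows_per_partition=125000):
        # Aim for roughly rows_per_partition samples per partition,
        # clamped to the 200..2000 range this change uses.
        return int(max(200, min(2000, nb_samples / rows_per_partition)))

    assert dbn_partition_count(10 * 1000 * 1000) == 200    # small data: keep the floor
    assert dbn_partition_count(50 * 1000 * 1000) == 400    # scales with the data
    assert dbn_partition_count(500 * 1000 * 1000) == 2000  # capped at the ceiling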
Change-Id: I14d663f49a54b7bd130186aebfbeffde1e1a6d82
---
M mjolnir/dbn.py
M mjolnir/utilities/data_pipeline.py
2 files changed, 19 insertions(+), 6 deletions(-)
git pull ssh://gerrit.wikimedia.org:29418/search/MjoLniR refs/changes/28/391728/1
diff --git a/mjolnir/dbn.py b/mjolnir/dbn.py
index 536064e..8c4ab76 100644
--- a/mjolnir/dbn.py
+++ b/mjolnir/dbn.py
@@ -9,6 +9,7 @@
import json
import pyspark.sql
from pyspark.sql import functions as F
+from pyspark.sql import types as T
import mjolnir.spark
@@ -179,7 +180,7 @@
model.train(sessions)
return _extract_labels_from_dbn(model, reader)
- return (
+ rdd_rel = (
df
# group and collect up the hits for individual (wikiid, norm_query_id,
# session_id) tuples to match how the dbn expects to receive data.
@@ -192,7 +193,14 @@
# of grouping into python, but that could just as well end up worse?
.repartition(num_partitions, 'wikiid', 'norm_query_id')
# Run each partition through the DBN to generate relevance scores.
- .rdd.mapPartitions(train_partition)
- # Convert the rdd of tuples back into a DataFrame so the fields all
- # have a name.
- .toDF(['wikiid', 'norm_query_id', 'hit_page_id', 'relevance']))
+ .rdd.mapPartitions(train_partition))
+
+ # Using toDF() is very slow as it has to run some of the partitions to check their
+ # types, and then run all the partitions later to get the actual data. To prevent
+ # running twice specify the schema we expect.
+ return df.sql_ctx.createDataFrame(rdd_rel, T.StructType([
+ T.StructField('wikiid', T.StringType(), False),
+ T.StructField('norm_query_id', T.LongType(), False),
+ T.StructField('hit_page_id', T.LongType(), False),
+ T.StructField('relevance', T.DoubleType(), False)
+ ]))
diff --git a/mjolnir/utilities/data_pipeline.py b/mjolnir/utilities/data_pipeline.py
index a5c37d1..c8e676f 100644
--- a/mjolnir/utilities/data_pipeline.py
+++ b/mjolnir/utilities/data_pipeline.py
@@ -85,9 +85,14 @@
print 'Fetched a total of %d samples for %d wikis' % (nb_samples,
len(wikis))
df_norm.unpersist()
+ # Target around 125k rows per partition. Note that this isn't
+ # how many the dbn will see, because it gets collected up. Just
+ # a rough guess.
+ dbn_partitions = int(max(200, min(2000, nb_samples / 125000)))
+
# Learn relevances
df_rel = (
- mjolnir.dbn.train(df_sampled, {
+ mjolnir.dbn.train(df_sampled, num_partitions=dbn_partitions, dbn_config={
'MAX_ITERATIONS': 40,
'DEBUG': False,
'PRETTY_LOG': True,
--
To view, visit https://gerrit.wikimedia.org/r/391728
To unsubscribe, visit https://gerrit.wikimedia.org/r/settings
Gerrit-MessageType: newchange
Gerrit-Change-Id: I14d663f49a54b7bd130186aebfbeffde1e1a6d82
Gerrit-PatchSet: 1
Gerrit-Project: search/MjoLniR
Gerrit-Branch: master
Gerrit-Owner: EBernhardson <[email protected]>