EBernhardson has uploaded a new change for review. (
https://gerrit.wikimedia.org/r/391729 )
Change subject: Repair ability to collect data for undersized wikis
......................................................................
Repair ability to collect data for undersized wikis
When attempting to collect data for small wikis that have much less data than
the provided samples_per_wiki in data_pipeline.py, we would fail because
the collected data was much less than expected. Rework this code to
allow for wikis that start with much less data than requested.
While digging into this I realized that this check was being done much
too early. It was calculating against data that was not of the same
shape — and therefore not the same counts — as the final data we feed into
feature collection. Everything between sampling and feature collection is
relatively cheap (compared to sending millions of queries to
elasticsearch), so move the check down to just before feature collection,
where we know exactly how many observations we have.
Change-Id: Ib9f8d9b6204d7568e02356c1062cf3263d8eedd6
---
M mjolnir/sampling.py
M mjolnir/test/test_sampling.py
M mjolnir/utilities/data_pipeline.py
3 files changed, 48 insertions(+), 24 deletions(-)
git pull ssh://gerrit.wikimedia.org:29418/search/MjoLniR
refs/changes/29/391729/1
diff --git a/mjolnir/sampling.py b/mjolnir/sampling.py
index 50f527a..50722b5 100644
--- a/mjolnir/sampling.py
+++ b/mjolnir/sampling.py
@@ -165,15 +165,18 @@
.agg(F.sum('num_hit_page_ids').alias('num_hit_page_ids'))
.collect())
+ hit_page_id_counts = {row.wikiid: row.num_hit_page_ids for row in
hit_page_id_counts}
+
wiki_percents = {}
needs_sampling = False
- for row in hit_page_id_counts:
- wiki_percents[row.wikiid] = min(1., float(samples_per_wiki) /
row.num_hit_page_ids)
- if wiki_percents[row.wikiid] < 1.:
+
+ for wikiid, num_hit_page_ids in hit_page_id_counts.items():
+ wiki_percents[wikiid] = min(1., float(samples_per_wiki) /
num_hit_page_ids)
+ if wiki_percents[wikiid] < 1.:
needs_sampling = True
if not needs_sampling:
- return df
+ return hit_page_id_counts, df
# Aggregate down into a unique set of (wikiid, norm_query_id) and add in a
# count of the number of unique sessions per pair. We will sample
per-strata
@@ -187,11 +190,11 @@
# Spark will eventually throw this away in an LRU fashion.
.cache())
- # materialize df_queries_unique so we can unpersist the input df
- df_queries_unique.count()
- df.unpersist()
-
df_queries_sampled = _sample_queries(df_queries_unique, wiki_percents,
seed=seed)
# Select the rows chosen by sampling from the input df
- return df.join(df_queries_sampled, how='inner', on=['wikiid',
'norm_query_id'])
+ df_sampled = df.join(df_queries_sampled, how='inner', on=['wikiid',
'norm_query_id'])
+ df_sampled.cache().count()
+ df.unpersist()
+
+ return hit_page_id_counts, df_sampled
diff --git a/mjolnir/test/test_sampling.py b/mjolnir/test/test_sampling.py
index 2feeb29..66a4605 100644
--- a/mjolnir/test/test_sampling.py
+++ b/mjolnir/test/test_sampling.py
@@ -20,8 +20,9 @@
('foo', 'e', 5, 'eee', list(range(3))),
]).toDF(['wikiid', 'query', 'norm_query_id', 'session_id', 'hit_page_ids'])
- sampled = mjolnir.sampling.sample(df, samples_per_wiki=100,
- seed=12345).collect()
+ hit_page_id_counts, df_sampled = mjolnir.sampling.sample(
+ df, samples_per_wiki=100, seed=12345)
+ sampled = df_sampled.collect()
# The sampling rate should have been chosen as 1.0, so we should have all
data
# regardless of probabilities.
assert len(sampled) == 5
@@ -60,8 +61,8 @@
# Using a constant seed ensures deterministic testing. Because this code
# actually relies on the law of large numbers, and we do not have large
# numbers here, many seeds probably fail.
- df_sampled = mjolnir.sampling.sample(df, samples_per_wiki=samples_per_wiki,
- seed=12345)
+ hit_page_id_counts, df_sampled = mjolnir.sampling.sample(
+ df, samples_per_wiki=samples_per_wiki, seed=12345)
sampled = (
df_sampled
.select('wikiid', 'query',
F.explode('hit_page_ids').alias('hit_page_id'))
diff --git a/mjolnir/utilities/data_pipeline.py
b/mjolnir/utilities/data_pipeline.py
index c8e676f..3827e08 100644
--- a/mjolnir/utilities/data_pipeline.py
+++ b/mjolnir/utilities/data_pipeline.py
@@ -64,11 +64,17 @@
min_sessions_per_query=min_sessions_per_query)
# Sample to some subset of queries per wiki
+ hit_page_id_counts, df_sampled_raw = mjolnir.sampling.sample(
+ df_norm,
+ seed=54321,
+ samples_per_wiki=samples_per_wiki)
+
+ df_sampled_raw.count()
+ df_norm.unpersist()
+
+ # Transform our dataframe into the shape expected by the DBN
df_sampled = (
- mjolnir.sampling.sample(
- df_norm,
- seed=54321,
- samples_per_wiki=samples_per_wiki)
+ df_sampled_raw
# Explode source into a row per displayed hit
.select('*', F.expr("posexplode(hit_page_ids)").alias('hit_position',
'hit_page_id'))
.drop('hit_page_ids')
@@ -79,16 +85,12 @@
# materialize df_sampled and unpersist df_norm
nb_samples = df_sampled.count()
- if ((nb_samples / float(len(wikis)*samples_per_wiki)) <
samples_size_tolerance):
- raise ValueError('Collected %d samples this is less than %d%% of the
requested sample size %d'
- % (nb_samples, samples_size_tolerance*100,
samples_per_wiki))
- print 'Fetched a total of %d samples for %d wikis' % (nb_samples,
len(wikis))
- df_norm.unpersist()
+ df_sampled_raw.unpersist()
# Target around 125k rows per partition. Note that this isn't
# how many the dbn will see, because it gets collected up. Just
# a rough guess.
- dbn_partitions = int(max(200, min(2000, nb_samples / 125000 ) ))
+ dbn_partitions = int(max(200, min(2000, nb_samples / 125000)))
# Learn relevances
df_rel = (
@@ -133,6 +135,23 @@
# materialize df_hits and drop df_all_hits
df_hits.count()
df_all_hits.unpersist()
+
+ actual_samples_per_wiki =
df_hits.groupby('wikiid').agg(F.count(F.lit(1)).alias('n_obs')).collect()
+ actual_samples_per_wiki = {row.wikiid: row.n_obs for row in
actual_samples_per_wiki}
+
+ not_enough_samples = []
+ for wiki in wikis:
+ # We cant have more samples than we started with
+ expected = min(samples_per_wiki, hit_page_id_counts[wiki])
+ actual = actual_samples_per_wiki[wiki]
+ if expected / float(actual) < samples_size_tolerance:
+ not_enough_samples.append(
+ 'Collected %d samples from %s which is less than %d%% of the
requested sample size %d'
+ % (actual, wiki, samples_size_tolerance*100, expected))
+ if not_enough_samples:
+ raise ValueError('\n'.join(not_enough_samples))
+
+ print 'Fetched a total of %d samples for %d wikis' %
(sum(actual_samples_per_wiki.values()), len(wikis))
# TODO: Training is per-wiki, should this be as well?
ndcgAt10 = mjolnir.metrics.ndcg(df_hits, 10, query_cols=['wikiid',
'query'])
@@ -216,7 +235,8 @@
help='The approximate number of rows in the final result per-wiki.')
parser.add_argument(
'-qe', '--sample-size-tolerance', dest='samples_size_tolerance',
type=float, default=0.5,
- help='The tolerance between the --samples-per-wiki set and the actual
number of rows fetched.')
+ help='The tolerance between the --samples-per-wiki set and the actual
number of rows fetched.'
+ + ' Higher requires closer match.')
parser.add_argument(
'-s', '--min-sessions', dest='min_sessions_per_query', type=int,
default=10,
help='The minimum number of sessions per normalized query')
--
To view, visit https://gerrit.wikimedia.org/r/391729
To unsubscribe, visit https://gerrit.wikimedia.org/r/settings
Gerrit-MessageType: newchange
Gerrit-Change-Id: Ib9f8d9b6204d7568e02356c1062cf3263d8eedd6
Gerrit-PatchSet: 1
Gerrit-Project: search/MjoLniR
Gerrit-Branch: master
Gerrit-Owner: EBernhardson <[email protected]>
_______________________________________________
MediaWiki-commits mailing list
[email protected]
https://lists.wikimedia.org/mailman/listinfo/mediawiki-commits