jenkins-bot has submitted this change and it was merged. ( 
https://gerrit.wikimedia.org/r/391729 )

Change subject: Repair ability to collect data for undersized wikis
......................................................................


Repair ability to collect data for undersized wikis

When attempting to collect data for small wikis that have much less data
than the provided samples_per_wiki in data_pipeline.py, we would fail
because the collected data was much less than expected. Rework this code
to allow for wikis that start with much less data than requested.

While digging into this I realized that this check was being done much
too early. It was calculated against data that did not have the same
shape — and therefore not the same counts — as the final data we feed
into feature collection. Everything between sampling and feature
collection is relatively cheap (compared to sending millions of queries
to elasticsearch), so move the check down to just before feature
collection, where we know exactly how many observations we have.

Change-Id: Ib9f8d9b6204d7568e02356c1062cf3263d8eedd6
---
M mjolnir/sampling.py
M mjolnir/test/test_sampling.py
M mjolnir/utilities/data_pipeline.py
3 files changed, 53 insertions(+), 24 deletions(-)

Approvals:
  jenkins-bot: Verified
  DCausse: Looks good to me, approved



diff --git a/mjolnir/sampling.py b/mjolnir/sampling.py
index 50f527a..d7d1f2b 100644
--- a/mjolnir/sampling.py
+++ b/mjolnir/sampling.py
@@ -165,15 +165,18 @@
         .agg(F.sum('num_hit_page_ids').alias('num_hit_page_ids'))
         .collect())
 
+    hit_page_id_counts = {row.wikiid: row.num_hit_page_ids for row in 
hit_page_id_counts}
+
     wiki_percents = {}
     needs_sampling = False
-    for row in hit_page_id_counts:
-        wiki_percents[row.wikiid] = min(1., float(samples_per_wiki) / 
row.num_hit_page_ids)
-        if wiki_percents[row.wikiid] < 1.:
+
+    for wikiid, num_hit_page_ids in hit_page_id_counts.items():
+        wiki_percents[wikiid] = min(1., float(samples_per_wiki) / 
num_hit_page_ids)
+        if wiki_percents[wikiid] < 1.:
             needs_sampling = True
 
     if not needs_sampling:
-        return df
+        return hit_page_id_counts, df
 
     # Aggregate down into a unique set of (wikiid, norm_query_id) and add in a
     # count of the number of unique sessions per pair. We will sample 
per-strata
@@ -184,14 +187,17 @@
         .agg(F.countDistinct('session_id').alias('num_sessions'))
         # This rdd will be used multiple times through strata generation and
         # sampling. Cache to not duplicate the filtering and aggregation work.
-        # Spark will eventually throw this away in an LRU fashion.
         .cache())
-
-    # materialize df_queries_unique so we can unpersist the input df
-    df_queries_unique.count()
-    df.unpersist()
 
     df_queries_sampled = _sample_queries(df_queries_unique, wiki_percents, 
seed=seed)
 
     # Select the rows chosen by sampling from the input df
-    return df.join(df_queries_sampled, how='inner', on=['wikiid', 
'norm_query_id'])
+    df_sampled = (
+        df
+        .join(df_queries_sampled, how='inner', on=['wikiid', 'norm_query_id'])
+        .cache())
+    df_sampled.count()
+    df.unpersist()
+    df_queries_unique.unpersist()
+
+    return hit_page_id_counts, df_sampled
diff --git a/mjolnir/test/test_sampling.py b/mjolnir/test/test_sampling.py
index 2feeb29..66a4605 100644
--- a/mjolnir/test/test_sampling.py
+++ b/mjolnir/test/test_sampling.py
@@ -20,8 +20,9 @@
         ('foo', 'e', 5, 'eee', list(range(3))),
     ]).toDF(['wikiid', 'query', 'norm_query_id', 'session_id', 'hit_page_ids'])
 
-    sampled = mjolnir.sampling.sample(df, samples_per_wiki=100,
-                                      seed=12345).collect()
+    hit_page_id_counts, df_sampled = mjolnir.sampling.sample(
+        df, samples_per_wiki=100, seed=12345)
+    sampled = df_sampled.collect()
     # The sampling rate should have been chosen as 1.0, so we should have all 
data
     # regardless of probabilities.
     assert len(sampled) == 5
@@ -60,8 +61,8 @@
     # Using a constant seed ensures deterministic testing. Because this code
     # actually relies on the law of large numbers, and we do not have large
     # numbers here, many seeds probably fail.
-    df_sampled = mjolnir.sampling.sample(df, samples_per_wiki=samples_per_wiki,
-                                         seed=12345)
+    hit_page_id_counts, df_sampled = mjolnir.sampling.sample(
+        df, samples_per_wiki=samples_per_wiki, seed=12345)
     sampled = (
         df_sampled
         .select('wikiid', 'query', 
F.explode('hit_page_ids').alias('hit_page_id'))
diff --git a/mjolnir/utilities/data_pipeline.py 
b/mjolnir/utilities/data_pipeline.py
index 401be45..20b06f0 100644
--- a/mjolnir/utilities/data_pipeline.py
+++ b/mjolnir/utilities/data_pipeline.py
@@ -64,11 +64,19 @@
         min_sessions_per_query=min_sessions_per_query)
 
     # Sample to some subset of queries per wiki
+    hit_page_id_counts, df_sampled_raw = mjolnir.sampling.sample(
+        df_norm,
+        seed=54321,
+        samples_per_wiki=samples_per_wiki)
+
+    # This should already be cached from sample, but lets be explicit
+    # to prevent future problems with refactoring.
+    df_sampled_raw.cache().count()
+    df_norm.unpersist()
+
+    # Transform our dataframe into the shape expected by the DBN
     df_sampled = (
-        mjolnir.sampling.sample(
-            df_norm,
-            seed=54321,
-            samples_per_wiki=samples_per_wiki)
+        df_sampled_raw
         # Explode source into a row per displayed hit
         .select('*', F.expr("posexplode(hit_page_ids)").alias('hit_position', 
'hit_page_id'))
         .drop('hit_page_ids')
@@ -79,11 +87,7 @@
 
     # materialize df_sampled and unpersist df_norm
     nb_samples = df_sampled.count()
-    if ((nb_samples / float(len(wikis)*samples_per_wiki)) < 
samples_size_tolerance):
-        raise ValueError('Collected %d samples this is less than %d%% of the 
requested sample size %d'
-                         % (nb_samples, samples_size_tolerance*100, 
samples_per_wiki))
-    print 'Fetched a total of %d samples for %d wikis' % (nb_samples, 
len(wikis))
-    df_norm.unpersist()
+    df_sampled_raw.unpersist()
 
     # Target around 125k rows per partition. Note that this isn't
     # how many the dbn will see, because it gets collected up. Just
@@ -133,6 +137,23 @@
     # materialize df_hits and drop df_all_hits
     df_hits.count()
     df_all_hits.unpersist()
+
+    actual_samples_per_wiki = 
df_hits.groupby('wikiid').agg(F.count(F.lit(1)).alias('n_obs')).collect()
+    actual_samples_per_wiki = {row.wikiid: row.n_obs for row in 
actual_samples_per_wiki}
+
+    not_enough_samples = []
+    for wiki in wikis:
+        # We cant have more samples than we started with
+        expected = min(samples_per_wiki, hit_page_id_counts[wiki])
+        actual = actual_samples_per_wiki[wiki]
+        if expected / float(actual) < samples_size_tolerance:
+            not_enough_samples.append(
+                'Collected %d samples from %s which is less than %d%% of the 
requested sample size %d'
+                % (actual, wiki, samples_size_tolerance*100, expected))
+    if not_enough_samples:
+        raise ValueError('\n'.join(not_enough_samples))
+
+    print 'Fetched a total of %d samples for %d wikis' % 
(sum(actual_samples_per_wiki.values()), len(wikis))
 
     # TODO: Training is per-wiki, should this be as well?
     ndcgAt10 = mjolnir.metrics.ndcg(df_hits, 10, query_cols=['wikiid', 
'query'])
@@ -216,7 +237,8 @@
         help='The approximate number of rows in the final result per-wiki.')
     parser.add_argument(
         '-qe', '--sample-size-tolerance', dest='samples_size_tolerance', 
type=float, default=0.5,
-        help='The tolerance between the --samples-per-wiki set and the actual 
number of rows fetched.')
+        help='The tolerance between the --samples-per-wiki set and the actual 
number of rows fetched.'
+             + ' Higher requires closer match.')
     parser.add_argument(
         '-s', '--min-sessions', dest='min_sessions_per_query', type=int, 
default=10,
         help='The minimum number of sessions per normalized query')

-- 
To view, visit https://gerrit.wikimedia.org/r/391729
To unsubscribe, visit https://gerrit.wikimedia.org/r/settings

Gerrit-MessageType: merged
Gerrit-Change-Id: Ib9f8d9b6204d7568e02356c1062cf3263d8eedd6
Gerrit-PatchSet: 3
Gerrit-Project: search/MjoLniR
Gerrit-Branch: master
Gerrit-Owner: EBernhardson <[email protected]>
Gerrit-Reviewer: DCausse <[email protected]>
Gerrit-Reviewer: EBernhardson <[email protected]>
Gerrit-Reviewer: jenkins-bot <>

_______________________________________________
MediaWiki-commits mailing list
[email protected]
https://lists.wikimedia.org/mailman/listinfo/mediawiki-commits

Reply via email to