Updated Branches: refs/heads/master 1138d3855 -> 96f5e9f8c
CRUNCH-178: Add reservoir sampling functions to lib.Sample and make the reservoir and regular sampling APIs consistent. Project: http://git-wip-us.apache.org/repos/asf/crunch/repo Commit: http://git-wip-us.apache.org/repos/asf/crunch/commit/96f5e9f8 Tree: http://git-wip-us.apache.org/repos/asf/crunch/tree/96f5e9f8 Diff: http://git-wip-us.apache.org/repos/asf/crunch/diff/96f5e9f8 Branch: refs/heads/master Commit: 96f5e9f8cbaf387c93204db2e9d15430154124cb Parents: 1138d38 Author: Josh Wills <[email protected]> Authored: Wed Mar 6 15:09:28 2013 -0800 Committer: Josh Wills <[email protected]> Committed: Wed Mar 13 01:04:09 2013 -0700 ---------------------------------------------------------------------- .../main/java/org/apache/crunch/lib/Sample.java | 197 ++++++++++++--- .../java/org/apache/crunch/lib/SampleUtils.java | 161 ++++++++++++ .../java/org/apache/crunch/lib/SampleTest.java | 37 +++- 3 files changed, 356 insertions(+), 39 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/crunch/blob/96f5e9f8/crunch/src/main/java/org/apache/crunch/lib/Sample.java ---------------------------------------------------------------------- diff --git a/crunch/src/main/java/org/apache/crunch/lib/Sample.java b/crunch/src/main/java/org/apache/crunch/lib/Sample.java index 5be2292..be75ae2 100644 --- a/crunch/src/main/java/org/apache/crunch/lib/Sample.java +++ b/crunch/src/main/java/org/apache/crunch/lib/Sample.java @@ -17,51 +17,37 @@ */ package org.apache.crunch.lib; -import java.util.Random; -import org.apache.crunch.FilterFn; +import org.apache.crunch.MapFn; import org.apache.crunch.PCollection; import org.apache.crunch.PTable; import org.apache.crunch.Pair; +import org.apache.crunch.lib.SampleUtils.ReservoirSampleFn; +import org.apache.crunch.lib.SampleUtils.SampleFn; +import org.apache.crunch.lib.SampleUtils.WRSCombineFn; +import org.apache.crunch.types.PTableType; +import org.apache.crunch.types.PType; +import 
org.apache.crunch.types.PTypeFamily; -import com.google.common.base.Preconditions; - +/** + * Methods for performing random sampling in a distributed fashion, either by accepting each + * record in a {@code PCollection} with an independent probability in order to sample some + * fraction of the overall data set, or by using reservoir sampling in order to pull a uniform + * or weighted sample of fixed size from a {@code PCollection} of an unknown size. For more details + * on the reservoir sampling algorithms used by this library, see the A-ES algorithm described in + * <a href="http://arxiv.org/pdf/1012.0256.pdf">Efraimidis (2012)</a>. + */ public class Sample { - private static class SamplerFn<S> extends FilterFn<S> { - - private final long seed; - private final double acceptanceProbability; - private transient Random r; - - public SamplerFn(long seed, double acceptanceProbability) { - Preconditions.checkArgument(0.0 < acceptanceProbability && acceptanceProbability < 1.0); - this.seed = seed; - this.acceptanceProbability = acceptanceProbability; - } - - @Override - public void initialize() { - if (r == null) { - r = new Random(seed); - } - } - - @Override - public boolean accept(S input) { - return r.nextDouble() < acceptanceProbability; - } - } - /** * Output records from the given {@code PCollection} with the given probability. * * @param input The {@code PCollection} to sample from - * @param probability The probability (0.0 < p < 1.0) + * @param probability The probability (0.0 < p < 1.0) * @return The output {@code PCollection} created from sampling */ public static <S> PCollection<S> sample(PCollection<S> input, double probability) { - return sample(input, System.currentTimeMillis(), probability); + return sample(input, null, probability); } /** @@ -69,26 +55,163 @@ public class Sample { * testing. 
* * @param input The {@code PCollection} to sample from - * @param seed The seed - * @param probability The probability (0.0 < p < 1.0) + * @param seed The seed for the random number generator + * @param probability The probability (0.0 < p < 1.0) * @return The output {@code PCollection} created from sampling */ - public static <S> PCollection<S> sample(PCollection<S> input, long seed, double probability) { + public static <S> PCollection<S> sample(PCollection<S> input, Long seed, double probability) { String stageName = String.format("sample(%.2f)", probability); - return input.parallelDo(stageName, new SamplerFn<S>(seed, probability), input.getPType()); + return input.parallelDo(stageName, new SampleFn<S>(probability, seed), input.getPType()); } /** * A {@code PTable<K, V>} analogue of the {@code sample} function. + * + * @param input The {@code PTable} to sample from + * @param probability The probability (0.0 < p < 1.0) + * @return The output {@code PTable} created from sampling */ public static <K, V> PTable<K, V> sample(PTable<K, V> input, double probability) { return PTables.asPTable(sample((PCollection<Pair<K, V>>) input, probability)); } /** - * A {@code PTable<K, V>} analogue of the {@code sample} function. + * A {@code PTable<K, V>} analogue of the {@code sample} function, with the seed argument + * exposed for testing purposes. 
+ * + * @param input The {@code PTable} to sample from + * @param seed The seed for the random number generator + * @param probability The probability (0.0 < p < 1.0) + * @return The output {@code PTable} created from sampling */ - public static <K, V> PTable<K, V> sample(PTable<K, V> input, long seed, double probability) { + public static <K, V> PTable<K, V> sample(PTable<K, V> input, Long seed, double probability) { return PTables.asPTable(sample((PCollection<Pair<K, V>>) input, seed, probability)); } + + /** + * Select a fixed number of elements from the given {@code PCollection} with each element + * equally likely to be included in the sample. + * + * @param input The input data + * @param sampleSize The number of elements to select + * @return A {@code PCollection} made up of the sampled elements + */ + public static <T> PCollection<T> reservoirSample( + PCollection<T> input, + int sampleSize) { + return reservorSample(input, sampleSize, null); + } + + /** + * A version of the reservoir sampling algorithm that uses a given seed, primarily for + * testing purposes. + * + * @param input The input data + * @param sampleSize The number of elements to select + * @param seed The test seed + * @return A {@code PCollection} made up of the sampled elements + + */ + public static <T> PCollection<T> reservorSample( + PCollection<T> input, + int sampleSize, + Long seed) { + PTypeFamily ptf = input.getTypeFamily(); + PType<Pair<T, Integer>> ptype = ptf.pairs(input.getPType(), ptf.ints()); + return weightedReservoirSample( + input.parallelDo(new MapFn<T, Pair<T, Integer>>() { + public Pair<T, Integer> map(T t) { return Pair.of(t, 1); } + }, ptype), + sampleSize, + seed); + } + + /** + * Selects a weighted sample of the elements of the given {@code PCollection}, where the second term in + * the input {@code Pair} is a numerical weight. 
+ * + * @param input the weighted observations + * @param sampleSize The number of elements to select + * @return A random sample of the given size that respects the weighting values + */ + public static <T, N extends Number> PCollection<T> weightedReservoirSample( + PCollection<Pair<T, N>> input, + int sampleSize) { + return weightedReservoirSample(input, sampleSize, null); + } + + /** + * The weighted reservoir sampling function with the seed term exposed for testing purposes. + * + * @param input the weighted observations + * @param sampleSize The number of elements to select + * @param seed The test seed + * @return A random sample of the given size that respects the weighting values + */ + public static <T, N extends Number> PCollection<T> weightedReservoirSample( + PCollection<Pair<T, N>> input, + int sampleSize, + Long seed) { + PTypeFamily ptf = input.getTypeFamily(); + PTable<Integer, Pair<T, N>> groupedIn = input.parallelDo( + new MapFn<Pair<T, N>, Pair<Integer, Pair<T, N>>>() { + @Override + public Pair<Integer, Pair<T, N>> map(Pair<T, N> p) { + return Pair.of(0, p); + } + }, ptf.tableOf(ptf.ints(), input.getPType())); + int[] ss = new int[] { sampleSize }; + return groupedWeightedReservoirSample(groupedIn, ss, seed) + .parallelDo(new MapFn<Pair<Integer, T>, T>() { + @Override + public T map(Pair<Integer, T> p) { + return p.second(); + } + }, (PType<T>) input.getPType().getSubTypes().get(0)); + } + + /** + * The most general purpose of the weighted reservoir sampling patterns that allows us to choose + * a random sample of elements for each of N input groups. 
+ * + * @param input A {@code PTable} with the key a group ID and the value a weighted observation in that group + * @param sampleSizes An array of length N, where each entry is the number of elements to include in that group + * @return A {@code PCollection} of the sampled elements for each of the groups + */ + + public static <T, N extends Number> PCollection<Pair<Integer, T>> groupedWeightedReservoirSample( + PTable<Integer, Pair<T, N>> input, + int[] sampleSizes) { + return groupedWeightedReservoirSample(input, sampleSizes, null); + } + + /** + * Same as the other groupedWeightedReservoirSample method, but includes a seed for testing + * purposes. + * + * @param input A {@code PTable} with the key a group ID and the value a weighted observation in that group + * @param sampleSizes An array of length N, where each entry is the number of elements to include in that group + * @param seed The test seed + * @return A {@code PCollection} of the sampled elements for each of the groups + */ + public static <T, N extends Number> PCollection<Pair<Integer, T>> groupedWeightedReservoirSample( + PTable<Integer, Pair<T, N>> input, + int[] sampleSizes, + Long seed) { + PTypeFamily ptf = input.getTypeFamily(); + PType<T> ttype = (PType<T>) input.getPTableType().getValueType().getSubTypes().get(0); + PTableType<Integer, Pair<Double, T>> ptt = ptf.tableOf(ptf.ints(), + ptf.pairs(ptf.doubles(), ttype)); + + return input.parallelDo(new ReservoirSampleFn<T, N>(sampleSizes, seed), ptt) + .groupByKey(1) + .combineValues(new WRSCombineFn<T>(sampleSizes)) + .parallelDo(new MapFn<Pair<Integer, Pair<Double, T>>, Pair<Integer, T>>() { + @Override + public Pair<Integer, T> map(Pair<Integer, Pair<Double, T>> p) { + return Pair.of(p.first(), p.second().second()); + } + }, ptf.pairs(ptf.ints(), ttype)); + } + } http://git-wip-us.apache.org/repos/asf/crunch/blob/96f5e9f8/crunch/src/main/java/org/apache/crunch/lib/SampleUtils.java 
---------------------------------------------------------------------- diff --git a/crunch/src/main/java/org/apache/crunch/lib/SampleUtils.java b/crunch/src/main/java/org/apache/crunch/lib/SampleUtils.java new file mode 100644 index 0000000..cbc30e4 --- /dev/null +++ b/crunch/src/main/java/org/apache/crunch/lib/SampleUtils.java @@ -0,0 +1,161 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.crunch.lib; + +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.SortedMap; + +import org.apache.crunch.CombineFn; +import org.apache.crunch.DoFn; +import org.apache.crunch.Emitter; +import org.apache.crunch.FilterFn; +import org.apache.crunch.Pair; + +import com.google.common.base.Preconditions; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; + +class SampleUtils { + + static class SampleFn<S> extends FilterFn<S> { + + private final Long seed; + private final double acceptanceProbability; + private transient Random r; + + public SampleFn(double acceptanceProbability, Long seed) { + Preconditions.checkArgument(0.0 < acceptanceProbability && acceptanceProbability < 1.0); + this.seed = seed == null ? 
System.currentTimeMillis() : seed; + this.acceptanceProbability = acceptanceProbability; + } + + @Override + public void initialize() { + if (r == null) { + r = new Random(seed); + } + } + + @Override + public boolean accept(S input) { + return r.nextDouble() < acceptanceProbability; + } + } + + + static class ReservoirSampleFn<T, N extends Number> + extends DoFn<Pair<Integer, Pair<T, N>>, Pair<Integer, Pair<Double, T>>> { + + private int[] sampleSizes; + private Long seed; + private transient List<SortedMap<Double, T>> reservoirs; + private transient Random random; + + public ReservoirSampleFn(int[] sampleSizes, Long seed) { + this.sampleSizes = sampleSizes; + this.seed = seed; + } + + @Override + public void initialize() { + this.reservoirs = Lists.newArrayList(); + for (int i = 0; i < sampleSizes.length; i++) { + reservoirs.add(Maps.<Double, T>newTreeMap()); + } + if (random == null) { + if (seed == null) { + this.random = new Random(); + } else { + this.random = new Random(seed); + } + } + } + + @Override + public void process(Pair<Integer, Pair<T, N>> input, + Emitter<Pair<Integer, Pair<Double, T>>> emitter) { + int id = input.first(); + Pair<T, N> p = input.second(); + double weight = p.second().doubleValue(); + if (weight > 0.0) { + double score = Math.log(random.nextDouble()) / weight; + SortedMap<Double, T> reservoir = reservoirs.get(id); + if (reservoir.size() < sampleSizes[id]) { + reservoir.put(score, p.first()); + } else if (score > reservoir.firstKey()) { + reservoir.remove(reservoir.firstKey()); + reservoir.put(score, p.first()); + } + } + } + + @Override + public void cleanup(Emitter<Pair<Integer, Pair<Double, T>>> emitter) { + for (int id = 0; id < reservoirs.size(); id++) { + SortedMap<Double, T> reservoir = reservoirs.get(id); + for (Map.Entry<Double, T> e : reservoir.entrySet()) { + emitter.emit(Pair.of(id, Pair.of(e.getKey(), e.getValue()))); + } + } + } + } + + static class WRSCombineFn<T> extends CombineFn<Integer, Pair<Double, T>> { + + 
private int[] sampleSizes; + private List<SortedMap<Double, T>> reservoirs; + + public WRSCombineFn(int[] sampleSizes) { + this.sampleSizes = sampleSizes; + } + + @Override + public void initialize() { + this.reservoirs = Lists.newArrayList(); + for (int i = 0; i < sampleSizes.length; i++) { + reservoirs.add(Maps.<Double, T>newTreeMap()); + } + } + + @Override + public void process(Pair<Integer, Iterable<Pair<Double, T>>> input, + Emitter<Pair<Integer, Pair<Double, T>>> emitter) { + SortedMap<Double, T> reservoir = reservoirs.get(input.first()); + for (Pair<Double, T> p : input.second()) { + if (reservoir.size() < sampleSizes[input.first()]) { + reservoir.put(p.first(), p.second()); + } else if (p.first() > reservoir.firstKey()) { + reservoir.remove(reservoir.firstKey()); + reservoir.put(p.first(), p.second()); + } + } + } + + @Override + public void cleanup(Emitter<Pair<Integer, Pair<Double, T>>> emitter) { + for (int i = 0; i < reservoirs.size(); i++) { + SortedMap<Double, T> reservoir = reservoirs.get(i); + for (Map.Entry<Double, T> e : reservoir.entrySet()) { + emitter.emit(Pair.of(i, Pair.of(e.getKey(), e.getValue()))); + } + } + } + } +} http://git-wip-us.apache.org/repos/asf/crunch/blob/96f5e9f8/crunch/src/test/java/org/apache/crunch/lib/SampleTest.java ---------------------------------------------------------------------- diff --git a/crunch/src/test/java/org/apache/crunch/lib/SampleTest.java b/crunch/src/test/java/org/apache/crunch/lib/SampleTest.java index 69fd074..bd6fd81 100644 --- a/crunch/src/test/java/org/apache/crunch/lib/SampleTest.java +++ b/crunch/src/test/java/org/apache/crunch/lib/SampleTest.java @@ -20,18 +20,51 @@ package org.apache.crunch.lib; import static org.junit.Assert.assertEquals; import java.util.List; +import java.util.Map; import org.apache.crunch.PCollection; +import org.apache.crunch.Pair; import org.apache.crunch.impl.mem.MemPipeline; +import org.apache.crunch.types.writable.Writables; import org.junit.Test; import 
com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Maps; public class SampleTest { + private PCollection<Pair<String, Double>> values = MemPipeline.typedCollectionOf( + Writables.pairs(Writables.strings(), Writables.doubles()), + ImmutableList.of( + Pair.of("foo", 200.0), + Pair.of("bar", 400.0), + Pair.of("baz", 100.0), + Pair.of("biz", 100.0))); + @Test - public void testSampler() { + public void testWRS() throws Exception { + Map<String, Integer> histogram = Maps.newHashMap(); + + for (int i = 0; i < 100; i++) { + PCollection<String> sample = Sample.weightedReservoirSample(values, 1, 1729L + i); + for (String s : sample.materialize()) { + if (!histogram.containsKey(s)) { + histogram.put(s, 1); + } else { + histogram.put(s, 1 + histogram.get(s)); + } + } + } + + Map<String, Integer> expected = ImmutableMap.of( + "foo", 24, "bar", 51, "baz", 13, "biz", 12); + assertEquals(expected, histogram); + } + + @Test + public void testSample() { PCollection<Integer> pcollect = MemPipeline.collectionOf(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); - Iterable<Integer> sample = Sample.sample(pcollect, 123998, 0.2).materialize(); + Iterable<Integer> sample = Sample.sample(pcollect, 123998L, 0.2).materialize(); List<Integer> sampleValues = ImmutableList.copyOf(sample); assertEquals(ImmutableList.of(6, 7), sampleValues); }
