Repository: flink Updated Branches: refs/heads/master 58421b848 -> c9cfb17cb
[FLINK-1901] [core] Create sample operator for Dataset. [FLINK-1901] [core] enable sample with fixed size on the whole dataset. [FLINK-1901] [core] add more comments for RandomSamplerTest. [FLINK-1901] [core] refactor PoissonSampler output Iterator. [FLINK-1901] [core] move sample/sampleWithSize operator to DataSetUtils. Adds notes for commons-math3 to LICENSE and NOTICE file This closes #949. Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/c9cfb17c Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/c9cfb17c Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/c9cfb17c Branch: refs/heads/master Commit: c9cfb17cb095def8b8ea0ed1b598fc78b890b874 Parents: 58421b8 Author: chengxiang li <chengxiang...@intel.com> Authored: Wed Jul 22 11:38:13 2015 +0800 Committer: Till Rohrmann <trohrm...@apache.org> Committed: Fri Aug 21 15:41:27 2015 +0200 ---------------------------------------------------------------------- flink-dist/src/main/flink-bin/LICENSE | 1 + flink-dist/src/main/flink-bin/NOTICE | 15 + flink-java/pom.xml | 6 + .../java/org/apache/flink/api/java/DataSet.java | 2 +- .../java/org/apache/flink/api/java/Utils.java | 4 + .../api/java/functions/SampleInCoordinator.java | 71 +++ .../api/java/functions/SampleInPartition.java | 71 +++ .../api/java/functions/SampleWithFraction.java | 68 +++ .../api/java/sampling/BernoulliSampler.java | 117 +++++ .../java/sampling/DistributedRandomSampler.java | 125 +++++ .../java/sampling/IntermediateSampleData.java | 47 ++ .../flink/api/java/sampling/PoissonSampler.java | 122 +++++ .../flink/api/java/sampling/RandomSampler.java | 63 +++ .../ReservoirSamplerWithReplacement.java | 110 +++++ .../ReservoirSamplerWithoutReplacement.java | 106 +++++ .../flink/api/java/utils/DataSetUtils.java | 95 ++++ .../api/java/sampling/RandomSamplerTest.java | 452 +++++++++++++++++++ .../apache/flink/api/scala/DataSetUtils.scala | 40 +- 
.../apache/flink/test/util/TestBaseUtils.java | 31 ++ .../test/javaApiOperators/SampleITCase.java | 167 +++++++ .../api/scala/operators/SampleITCase.scala | 167 +++++++ pom.xml | 6 + 22 files changed, 1884 insertions(+), 2 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/c9cfb17c/flink-dist/src/main/flink-bin/LICENSE ---------------------------------------------------------------------- diff --git a/flink-dist/src/main/flink-bin/LICENSE b/flink-dist/src/main/flink-bin/LICENSE index 281b8f0..e79ff71 100644 --- a/flink-dist/src/main/flink-bin/LICENSE +++ b/flink-dist/src/main/flink-bin/LICENSE @@ -277,6 +277,7 @@ under the Apache License (v 2.0): - Uncommons Math (org.uncommons.maths:uncommons-maths:1.2.2a - https://github.com/dwdyer/uncommons-maths) - Jansi (org.fusesource.jansi:jansi:1.4 - https://github.com/fusesource/jansi) - Apache Camel Core (org.apache.camel:camel-core:2.10.3 - http://camel.apache.org/camel-core.html) + - Apache Commons Math (org.apache.commons:commons-math3:3.5 - http://commons.apache.org/proper/commons-math/index.html) ----------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/c9cfb17c/flink-dist/src/main/flink-bin/NOTICE ---------------------------------------------------------------------- diff --git a/flink-dist/src/main/flink-bin/NOTICE b/flink-dist/src/main/flink-bin/NOTICE index a71e61d..7b0fe72 100644 --- a/flink-dist/src/main/flink-bin/NOTICE +++ b/flink-dist/src/main/flink-bin/NOTICE @@ -69,6 +69,21 @@ Copyright (c) 2002 JSON.org ----------------------------------------------------------------------- + Apache Commons Math +----------------------------------------------------------------------- + +Apache Commons Math +Copyright 2001-2015 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). 
+ +This product includes software developed for Orekit by +CS Systèmes d'Information (http://www.c-s.fr/) +Copyright 2010-2012 CS Systèmes d'Information + + +----------------------------------------------------------------------- Akka ----------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/c9cfb17c/flink-java/pom.xml ---------------------------------------------------------------------- diff --git a/flink-java/pom.xml b/flink-java/pom.xml index 683304f..d777048 100644 --- a/flink-java/pom.xml +++ b/flink-java/pom.xml @@ -92,6 +92,12 @@ under the License. <artifactId>guava</artifactId> <version>${guava.version}</version> </dependency> + + <dependency> + <groupId>org.apache.commons</groupId> + <artifactId>commons-math3</artifactId> + <!-- managed version --> + </dependency> <dependency> <groupId>org.apache.flink</groupId> http://git-wip-us.apache.org/repos/asf/flink/blob/c9cfb17c/flink-java/src/main/java/org/apache/flink/api/java/DataSet.java ---------------------------------------------------------------------- diff --git a/flink-java/src/main/java/org/apache/flink/api/java/DataSet.java b/flink-java/src/main/java/org/apache/flink/api/java/DataSet.java index 81ba279..98a94c6 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/DataSet.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/DataSet.java @@ -1057,7 +1057,7 @@ public abstract class DataSet<T> { public UnionOperator<T> union(DataSet<T> other){ return new UnionOperator<T>(this, other, Utils.getCallLocationName()); } - + // -------------------------------------------------------------------------------------------- // Partitioning // -------------------------------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/c9cfb17c/flink-java/src/main/java/org/apache/flink/api/java/Utils.java ---------------------------------------------------------------------- diff --git 
a/flink-java/src/main/java/org/apache/flink/api/java/Utils.java b/flink-java/src/main/java/org/apache/flink/api/java/Utils.java index a1e3d25..785f3ce 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/Utils.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/Utils.java @@ -28,6 +28,8 @@ import org.apache.flink.api.java.typeutils.GenericTypeInfo; import java.lang.reflect.Field; import java.lang.reflect.Modifier; import java.util.List; +import java.util.Random; + import org.apache.flink.api.common.functions.RichFlatMapFunction; import org.apache.flink.configuration.Configuration; @@ -36,6 +38,8 @@ import static org.apache.flink.api.java.functions.FunctionAnnotation.SkipCodeAna public class Utils { + + public static final Random RNG = new Random(); public static String getCallLocationName() { return getCallLocationName(4); http://git-wip-us.apache.org/repos/asf/flink/blob/c9cfb17c/flink-java/src/main/java/org/apache/flink/api/java/functions/SampleInCoordinator.java ---------------------------------------------------------------------- diff --git a/flink-java/src/main/java/org/apache/flink/api/java/functions/SampleInCoordinator.java b/flink-java/src/main/java/org/apache/flink/api/java/functions/SampleInCoordinator.java new file mode 100644 index 0000000..528d746 --- /dev/null +++ b/flink-java/src/main/java/org/apache/flink/api/java/functions/SampleInCoordinator.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.api.java.functions; + +import org.apache.flink.api.common.functions.GroupReduceFunction; +import org.apache.flink.api.java.sampling.IntermediateSampleData; +import org.apache.flink.api.java.sampling.DistributedRandomSampler; +import org.apache.flink.api.java.sampling.ReservoirSamplerWithReplacement; +import org.apache.flink.api.java.sampling.ReservoirSamplerWithoutReplacement; +import org.apache.flink.util.Collector; + +import java.util.Iterator; + +/** + * SampleInCoordinator wraps the sample logic of the coordinator side (the second phase of + * distributed sample algorithm). It executes the coordinator side sample logic in an all reduce + * function. The user needs to make sure that the operator parallelism of this function is 1 to + * make sure this is a central coordinator. Besides, we do not need the task index information for + * random generator seed as the parallelism must be 1. + * + * @param <T> the data type wrapped in IntermediateSampleData as input. + */ +public class SampleInCoordinator<T> implements GroupReduceFunction<IntermediateSampleData<T>, T> { + + private boolean withReplacement; + private int numSample; + private long seed; + + /** + * Create a function instance of SampleInCoordinator. + * + * @param withReplacement Whether element can be selected more than once. + * @param numSample Fixed sample size. + * @param seed Random generator seed. 
+ */ + public SampleInCoordinator(boolean withReplacement, int numSample, long seed) { + this.withReplacement = withReplacement; + this.numSample = numSample; + this.seed = seed; + } + + @Override + public void reduce(Iterable<IntermediateSampleData<T>> values, Collector<T> out) throws Exception { + DistributedRandomSampler<T> sampler; + if (withReplacement) { + sampler = new ReservoirSamplerWithReplacement<>(numSample, seed); + } else { + sampler = new ReservoirSamplerWithoutReplacement<>(numSample, seed); + } + + Iterator<T> sampled = sampler.sampleInCoordinator(values.iterator()); + while (sampled.hasNext()) { + out.collect(sampled.next()); + } + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/c9cfb17c/flink-java/src/main/java/org/apache/flink/api/java/functions/SampleInPartition.java ---------------------------------------------------------------------- diff --git a/flink-java/src/main/java/org/apache/flink/api/java/functions/SampleInPartition.java b/flink-java/src/main/java/org/apache/flink/api/java/functions/SampleInPartition.java new file mode 100644 index 0000000..295fb44 --- /dev/null +++ b/flink-java/src/main/java/org/apache/flink/api/java/functions/SampleInPartition.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.api.java.functions; + +import org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.api.java.sampling.IntermediateSampleData; +import org.apache.flink.api.java.sampling.DistributedRandomSampler; +import org.apache.flink.api.java.sampling.ReservoirSamplerWithReplacement; +import org.apache.flink.api.java.sampling.ReservoirSamplerWithoutReplacement; +import org.apache.flink.util.Collector; + +import java.util.Iterator; + +/** + * SampleInPartition wraps the sample logic on the partition side (the first phase of distributed + * sample algorithm). It executes the partition side sample logic in a mapPartition function. + * + * @param <T> The type of input data + */ +public class SampleInPartition<T> extends RichMapPartitionFunction<T, IntermediateSampleData<T>> { + + private boolean withReplacement; + private int numSample; + private long seed; + + /** + * Create a function instance of SampleInPartition. + * + * @param withReplacement Whether element can be selected more than once. + * @param numSample Fixed sample size. + * @param seed Random generator seed. 
+ */ + public SampleInPartition(boolean withReplacement, int numSample, long seed) { + this.withReplacement = withReplacement; + this.numSample = numSample; + this.seed = seed; + } + + @Override + public void mapPartition(Iterable<T> values, Collector<IntermediateSampleData<T>> out) throws Exception { + DistributedRandomSampler<T> sampler; + long seedAndIndex = seed + getRuntimeContext().getIndexOfThisSubtask(); + if (withReplacement) { + sampler = new ReservoirSamplerWithReplacement<T>(numSample, seedAndIndex); + } else { + sampler = new ReservoirSamplerWithoutReplacement<T>(numSample, seedAndIndex); + } + + Iterator<IntermediateSampleData<T>> sampled = sampler.sampleInPartition(values.iterator()); + while (sampled.hasNext()) { + out.collect(sampled.next()); + } + } +} + + http://git-wip-us.apache.org/repos/asf/flink/blob/c9cfb17c/flink-java/src/main/java/org/apache/flink/api/java/functions/SampleWithFraction.java ---------------------------------------------------------------------- diff --git a/flink-java/src/main/java/org/apache/flink/api/java/functions/SampleWithFraction.java b/flink-java/src/main/java/org/apache/flink/api/java/functions/SampleWithFraction.java new file mode 100644 index 0000000..4ef9aa0 --- /dev/null +++ b/flink-java/src/main/java/org/apache/flink/api/java/functions/SampleWithFraction.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.api.java.functions; + +import org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.api.java.sampling.BernoulliSampler; +import org.apache.flink.api.java.sampling.PoissonSampler; +import org.apache.flink.api.java.sampling.RandomSampler; +import org.apache.flink.util.Collector; + +import java.util.Iterator; + +/** + * A map partition function wrapper for sampling algorithms with fraction, the sample algorithm + * takes the partition iterator as input. + * + * @param <T> + */ +public class SampleWithFraction<T> extends RichMapPartitionFunction<T, T> { + + private boolean withReplacement; + private double fraction; + private long seed; + + /** + * Create a function instance of SampleWithFraction. + * + * @param withReplacement Whether element can be selected more than once. + * @param fraction Probability that each element is selected. + * @param seed random number generator seed. 
+ */ + public SampleWithFraction(boolean withReplacement, double fraction, long seed) { + this.withReplacement = withReplacement; + this.fraction = fraction; + this.seed = seed; + } + + @Override + public void mapPartition(Iterable<T> values, Collector<T> out) throws Exception { + RandomSampler<T> sampler; + long seedAndIndex = seed + getRuntimeContext().getIndexOfThisSubtask(); + if (withReplacement) { + sampler = new PoissonSampler<>(fraction, seedAndIndex); + } else { + sampler = new BernoulliSampler<>(fraction, seedAndIndex); + } + + Iterator<T> sampled = sampler.sample(values.iterator()); + while (sampled.hasNext()) { + out.collect(sampled.next()); + } + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/c9cfb17c/flink-java/src/main/java/org/apache/flink/api/java/sampling/BernoulliSampler.java ---------------------------------------------------------------------- diff --git a/flink-java/src/main/java/org/apache/flink/api/java/sampling/BernoulliSampler.java b/flink-java/src/main/java/org/apache/flink/api/java/sampling/BernoulliSampler.java new file mode 100644 index 0000000..0f5ecc6 --- /dev/null +++ b/flink-java/src/main/java/org/apache/flink/api/java/sampling/BernoulliSampler.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.api.java.sampling; + +import com.google.common.base.Preconditions; + +import java.util.Iterator; +import java.util.Random; + +/** + * A sampler implementation built upon a Bernoulli trial. This sampler is used to sample with + * fraction and without replacement. Whether an element is sampled or not is determined by a + * Bernoulli experiment. + * + * @param <T> The type of sample. + */ +public class BernoulliSampler<T> extends RandomSampler<T> { + + private final double fraction; + private final Random random; + + /** + * Create a Bernoulli sampler with sample fraction and default random number generator. + * + * @param fraction Sample fraction, aka the Bernoulli sampler probability. + */ + public BernoulliSampler(double fraction) { + this(fraction, new Random()); + } + + /** + * Create a Bernoulli sampler with sample fraction and random number generator seed. + * + * @param fraction Sample fraction, aka the Bernoulli sampler probability. + * @param seed Random number generator seed. + */ + public BernoulliSampler(double fraction, long seed) { + this(fraction, new Random(seed)); + } + + /** + * Create a Bernoulli sampler with sample fraction and random number generator. + * + * @param fraction Sample fraction, aka the Bernoulli sampler probability. + * @param random The random number generator. + */ + public BernoulliSampler(double fraction, Random random) { + Preconditions.checkArgument(fraction >= 0 && fraction <= 1.0d, "fraction fraction must between [0, 1]."); + this.fraction = fraction; + this.random = random; + } + + /** + * Sample the input elements, for each input element, take a Bernoulli trial for sampling. + * + * @param input Elements to be sampled. + * @return The sampled result which is lazily computed upon input elements. 
+ */ + @Override + public Iterator<T> sample(final Iterator<T> input) { + if (fraction == 0) { + return EMPTY_ITERABLE; + } + + return new SampledIterator<T>() { + T current = null; + + @Override + public boolean hasNext() { + if (current == null) { + current = getNextSampledElement(); + } + + return current != null; + } + + @Override + public T next() { + if (current == null) { + return getNextSampledElement(); + } else { + T result = current; + current = null; + + return result; + } + } + + private T getNextSampledElement() { + while (input.hasNext()) { + T element = input.next(); + + if (random.nextDouble() <= fraction) { + return element; + } + } + + return null; + } + }; + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/c9cfb17c/flink-java/src/main/java/org/apache/flink/api/java/sampling/DistributedRandomSampler.java ---------------------------------------------------------------------- diff --git a/flink-java/src/main/java/org/apache/flink/api/java/sampling/DistributedRandomSampler.java b/flink-java/src/main/java/org/apache/flink/api/java/sampling/DistributedRandomSampler.java new file mode 100644 index 0000000..e5a719f --- /dev/null +++ b/flink-java/src/main/java/org/apache/flink/api/java/sampling/DistributedRandomSampler.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.api.java.sampling; + +import java.util.Iterator; +import java.util.PriorityQueue; + +/** + * For sampling with fraction, the sample algorithms are natively distributed, while it's not + * true for fixed size sample algorithms. The fixed size sample algorithms require two-phases + * sampling (according to our current implementation). In the first phase, each distributed + * partition is sampled independently. The partial sampling results are handled by a central + * coordinator. The central coordinator combines the partial sampling results to form the final + * result. + * + * @param <T> The input data type. + */ +public abstract class DistributedRandomSampler<T> extends RandomSampler<T> { + + protected final int numSamples; + + public DistributedRandomSampler(int numSamples) { + this.numSamples = numSamples; + } + + protected final Iterator<IntermediateSampleData<T>> EMPTY_INTERMEDIATE_ITERABLE = + new SampledIterator<IntermediateSampleData<T>>() { + @Override + public boolean hasNext() { + return false; + } + + @Override + public IntermediateSampleData<T> next() { + return null; + } + }; + + /** + * Sample algorithm for the first phase. It operates on a single partition. + * + * @param input The DataSet input of each partition. + * @return Intermediate sample output which will be used as the input of the second phase. + */ + public abstract Iterator<IntermediateSampleData<T>> sampleInPartition(Iterator<T> input); + + /** + * Sample algorithm for the second phase. This operation should be executed as the UDF of + * an all reduce operation. + * + * @param input The intermediate sample output generated in the first phase. + * @return The sampled output. 
+ */ + public Iterator<T> sampleInCoordinator(Iterator<IntermediateSampleData<T>> input) { + if (numSamples == 0) { + return EMPTY_ITERABLE; + } + + // This queue holds fixed number elements with the top K weight for the coordinator. + PriorityQueue<IntermediateSampleData<T>> reservoir = new PriorityQueue<IntermediateSampleData<T>>(numSamples); + int index = 0; + IntermediateSampleData<T> smallest = null; + while (input.hasNext()) { + IntermediateSampleData<T> element = input.next(); + if (index < numSamples) { + // Fill the queue with first K elements from input. + reservoir.add(element); + smallest = reservoir.peek(); + } else { + // If current element weight is larger than the smallest one in queue, remove the element + // with the smallest weight, and append current element into the queue. + if (element.getWeight() > smallest.getWeight()) { + reservoir.remove(); + reservoir.add(element); + smallest = reservoir.peek(); + } + } + index++; + } + final Iterator<IntermediateSampleData<T>> itr = reservoir.iterator(); + + return new Iterator<T>() { + @Override + public boolean hasNext() { + return itr.hasNext(); + } + + @Override + public T next() { + return itr.next().getElement(); + } + + @Override + public void remove() { + itr.remove(); + } + }; + } + + /** + * Combine the first phase and second phase in sequence, implemented for test purpose only. + * + * @param input Source data. + * @return Sample result in sequence. 
+ */ + @Override + public Iterator<T> sample(Iterator<T> input) { + return sampleInCoordinator(sampleInPartition(input)); + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/c9cfb17c/flink-java/src/main/java/org/apache/flink/api/java/sampling/IntermediateSampleData.java ---------------------------------------------------------------------- diff --git a/flink-java/src/main/java/org/apache/flink/api/java/sampling/IntermediateSampleData.java b/flink-java/src/main/java/org/apache/flink/api/java/sampling/IntermediateSampleData.java new file mode 100644 index 0000000..1d70f19 --- /dev/null +++ b/flink-java/src/main/java/org/apache/flink/api/java/sampling/IntermediateSampleData.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.api.java.sampling; + +/** + * The data structure which is transferred between partitions and the coordinator for distributed + * random sampling. + * + * @param <T> The type of sample data. 
+ */ +public class IntermediateSampleData<T> implements Comparable<IntermediateSampleData<T>> { + private double weight; + private T element; + + public IntermediateSampleData(double weight, T element) { + this.weight = weight; + this.element = element; + } + + public double getWeight() { + return weight; + } + + public T getElement() { + return element; + } + + @Override + public int compareTo(IntermediateSampleData<T> other) { + return this.weight >= other.getWeight() ? 1 : -1; + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/c9cfb17c/flink-java/src/main/java/org/apache/flink/api/java/sampling/PoissonSampler.java ---------------------------------------------------------------------- diff --git a/flink-java/src/main/java/org/apache/flink/api/java/sampling/PoissonSampler.java b/flink-java/src/main/java/org/apache/flink/api/java/sampling/PoissonSampler.java new file mode 100644 index 0000000..3834d24 --- /dev/null +++ b/flink-java/src/main/java/org/apache/flink/api/java/sampling/PoissonSampler.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.api.java.sampling; + +import com.google.common.base.Preconditions; +import org.apache.commons.math3.distribution.PoissonDistribution; + +import java.util.Iterator; + +/** + * A sampler implementation based on the Poisson Distribution. While sampling elements with fraction + * and replacement, the selected number of each element follows a given poisson distribution. + * + * @param <T> The type of sample. + * @see <a href="https://en.wikipedia.org/wiki/Poisson_distribution">https://en.wikipedia.org/wiki/Poisson_distribution</a> + */ +public class PoissonSampler<T> extends RandomSampler<T> { + + private PoissonDistribution poissonDistribution; + private final double fraction; + + /** + * Create a poisson sampler which can sample elements with replacement. + * + * @param fraction The expected count of each element. + * @param seed Random number generator seed for internal PoissonDistribution. + */ + public PoissonSampler(double fraction, long seed) { + Preconditions.checkArgument(fraction >= 0, "fraction should be positive."); + this.fraction = fraction; + if (this.fraction > 0) { + this.poissonDistribution = new PoissonDistribution(fraction); + this.poissonDistribution.reseedRandomGenerator(seed); + } + } + + /** + * Create a poisson sampler which can sample elements with replacement. + * + * @param fraction The expected count of each element. + */ + public PoissonSampler(double fraction) { + Preconditions.checkArgument(fraction >= 0, "fraction should be non-negative."); + this.fraction = fraction; + if (this.fraction > 0) { + this.poissonDistribution = new PoissonDistribution(fraction); + } + } + + /** + * Sample the input elements, for each input element, generate its count following a poisson + * distribution. + * + * @param input Elements to be sampled. + * @return The sampled result which is lazy computed upon input elements. 
+ */ + @Override + public Iterator<T> sample(final Iterator<T> input) { + if (fraction == 0) { + return EMPTY_ITERABLE; + } + + return new SampledIterator<T>() { + T currentElement; + int currentCount = 0; + + @Override + public boolean hasNext() { + if (currentCount > 0) { + return true; + } else { + moveToNextElement(); + + if (currentCount > 0) { + return true; + } else { + return false; + } + } + } + + private void moveToNextElement() { + while (input.hasNext()) { + currentElement = input.next(); + currentCount = poissonDistribution.sample(); + if (currentCount > 0) { + break; + } + } + } + + @Override + public T next() { + if (currentCount == 0) { + moveToNextElement(); + } + + if (currentCount == 0) { + return null; + } else { + currentCount--; + return currentElement; + } + } + }; + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/c9cfb17c/flink-java/src/main/java/org/apache/flink/api/java/sampling/RandomSampler.java ---------------------------------------------------------------------- diff --git a/flink-java/src/main/java/org/apache/flink/api/java/sampling/RandomSampler.java b/flink-java/src/main/java/org/apache/flink/api/java/sampling/RandomSampler.java new file mode 100644 index 0000000..5fe2920 --- /dev/null +++ b/flink-java/src/main/java/org/apache/flink/api/java/sampling/RandomSampler.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.api.java.sampling; + +import java.util.Iterator; + +/** + * A data sample is a set of data selected from a statistical population by a defined procedure. + * RandomSampler helps to create data sample randomly. + * + * @param <T> The type of sampler data. + */ +public abstract class RandomSampler<T> { + + protected final Iterator<T> EMPTY_ITERABLE = new SampledIterator<T>() { + @Override + public boolean hasNext() { + return false; + } + + @Override + public T next() { + return null; + } + }; + + /** + * Randomly sample the elements from input in sequence, and return the result iterator. + * + * @param input Source data + * @return The sample result. + */ + public abstract Iterator<T> sample(Iterator<T> input); + +} + +/** + * A simple abstract iterator which implements the remove method as unsupported operation. + * + * @param <T> The type of iterator data. 
+ */ +abstract class SampledIterator<T> implements Iterator<T> { + @Override + public void remove() { + throw new UnsupportedOperationException("Do not support this operation."); + } + +} http://git-wip-us.apache.org/repos/asf/flink/blob/c9cfb17c/flink-java/src/main/java/org/apache/flink/api/java/sampling/ReservoirSamplerWithReplacement.java ---------------------------------------------------------------------- diff --git a/flink-java/src/main/java/org/apache/flink/api/java/sampling/ReservoirSamplerWithReplacement.java b/flink-java/src/main/java/org/apache/flink/api/java/sampling/ReservoirSamplerWithReplacement.java new file mode 100644 index 0000000..9c37154 --- /dev/null +++ b/flink-java/src/main/java/org/apache/flink/api/java/sampling/ReservoirSamplerWithReplacement.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.api.java.sampling; + +import com.google.common.base.Preconditions; + +import java.util.Iterator; +import java.util.PriorityQueue; +import java.util.Random; + +/** + * A simple in memory implementation of Reservoir Sampling with replacement and with only one pass + * through the input iteration whose size is unpredictable. 
The basic idea behind this sampler + * implementation is quite similar to {@link ReservoirSamplerWithoutReplacement}. The main + * difference is that, in the first phase, we generate weights for each element K times, so that + * each element can get selected multiple times. + * + * This implementation refers to the algorithm described in <a href="researcher.ibm.com/files/us-dpwoodru/tw11.pdf"> + * "Optimal Random Sampling from Distributed Streams Revisited"</a>. + * + * @param <T> The type of sample. + */ +public class ReservoirSamplerWithReplacement<T> extends DistributedRandomSampler<T> { + + private final Random random; + + /** + * Create a sampler with fixed sample size and default random number generator. + * + * @param numSamples Number of selected elements, must be non-negative. + */ + public ReservoirSamplerWithReplacement(int numSamples) { + this(numSamples, new Random()); + } + + /** + * Create a sampler with fixed sample size and random number generator seed. + * + * @param numSamples Number of selected elements, must be non-negative. + * @param seed Random number generator seed + */ + public ReservoirSamplerWithReplacement(int numSamples, long seed) { + this(numSamples, new Random(seed)); + } + + /** + * Create a sampler with fixed sample size and random number generator. + * + * @param numSamples Number of selected elements, must be non-negative. + * @param random Random number generator + */ + public ReservoirSamplerWithReplacement(int numSamples, Random random) { + super(numSamples); + Preconditions.checkArgument(numSamples >= 0, "numSamples should be non-negative."); + this.random = random; + } + + @Override + public Iterator<IntermediateSampleData<T>> sampleInPartition(Iterator<T> input) { + if (numSamples == 0) { + return EMPTY_INTERMEDIATE_ITERABLE; + } + + // This queue holds a fixed number of elements with the top K weight for current partition. 
+ PriorityQueue<IntermediateSampleData<T>> queue = new PriorityQueue<IntermediateSampleData<T>>(numSamples); + + IntermediateSampleData<T> smallest = null; + + if (input.hasNext()) { + T element = input.next(); + // Initiate the queue with the first element and random weights. + for (int i = 0; i < numSamples; i++) { + queue.add(new IntermediateSampleData<T>(random.nextDouble(), element)); + smallest = queue.peek(); + } + } + + while (input.hasNext()) { + T element = input.next(); + // To sample with replacement, we generate K random weights for each element, so that it's + // possible to be selected multi times. + for (int i = 0; i < numSamples; i++) { + // If current element weight is larger than the smallest one in queue, remove the element + // with the smallest weight, and append current element into the queue. + double rand = random.nextDouble(); + if (rand > smallest.getWeight()) { + queue.remove(); + queue.add(new IntermediateSampleData<T>(rand, element)); + smallest = queue.peek(); + } + } + } + return queue.iterator(); + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/c9cfb17c/flink-java/src/main/java/org/apache/flink/api/java/sampling/ReservoirSamplerWithoutReplacement.java ---------------------------------------------------------------------- diff --git a/flink-java/src/main/java/org/apache/flink/api/java/sampling/ReservoirSamplerWithoutReplacement.java b/flink-java/src/main/java/org/apache/flink/api/java/sampling/ReservoirSamplerWithoutReplacement.java new file mode 100644 index 0000000..b953bff --- /dev/null +++ b/flink-java/src/main/java/org/apache/flink/api/java/sampling/ReservoirSamplerWithoutReplacement.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.api.java.sampling; + +import com.google.common.base.Preconditions; + +import java.util.Iterator; +import java.util.PriorityQueue; +import java.util.Random; + +/** + * A simple in memory implementation of Reservoir Sampling without replacement, and with only one + * pass through the input iteration whose size is unpredictable. The basic idea behind this sampler + * implementation is to generate a random number for each input element as its weight, select the + * top K elements with max weight. As the weights are generated randomly, so are the selected + * top K elements. The algorithm is implemented using the {@link DistributedRandomSampler} + * interface. In the first phase, we generate random numbers as the weights for each element and + * select top K elements as the output of each partitions. In the second phase, we select top K + * elements from all the outputs of the first phase. + * + * This implementation refers to the algorithm described in <a href="researcher.ibm.com/files/us-dpwoodru/tw11.pdf"> + * "Optimal Random Sampling from Distributed Streams Revisited"</a>. + * + * @param <T> The type of the sampler. + */ +public class ReservoirSamplerWithoutReplacement<T> extends DistributedRandomSampler<T> { + + private final Random random; + + /** + * Create a new sampler with reservoir size and a supplied random number generator. 
+ * + * @param numSamples Maximum number of samples to retain in reservoir, must be non-negative. + * @param random Instance of random number generator for sampling. + */ + public ReservoirSamplerWithoutReplacement(int numSamples, Random random) { + super(numSamples); + Preconditions.checkArgument(numSamples >= 0, "numSamples should be non-negative."); + this.random = random; + } + + /** + * Create a new sampler with reservoir size and a default random number generator. + * + * @param numSamples Maximum number of samples to retain in reservoir, must be non-negative. + */ + public ReservoirSamplerWithoutReplacement(int numSamples) { + this(numSamples, new Random()); + } + + /** + * Create a new sampler with reservoir size and the seed for random number generator. + * + * @param numSamples Maximum number of samples to retain in reservoir, must be non-negative. + * @param seed Random number generator seed. + */ + public ReservoirSamplerWithoutReplacement(int numSamples, long seed) { + + this(numSamples, new Random(seed)); + } + + @Override + public Iterator<IntermediateSampleData<T>> sampleInPartition(Iterator<T> input) { + if (numSamples == 0) { + return EMPTY_INTERMEDIATE_ITERABLE; + } + + // This queue holds fixed number elements with the top K weight for current partition. + PriorityQueue<IntermediateSampleData<T>> queue = new PriorityQueue<IntermediateSampleData<T>>(numSamples); + int index = 0; + IntermediateSampleData<T> smallest = null; + while (input.hasNext()) { + T element = input.next(); + if (index < numSamples) { + // Fill the queue with first K elements from input. + queue.add(new IntermediateSampleData<T>(random.nextDouble(), element)); + smallest = queue.peek(); + } else { + double rand = random.nextDouble(); + // Remove the element with the smallest weight, and append current element into the queue. 
+ if (rand > smallest.getWeight()) { + queue.remove(); + queue.add(new IntermediateSampleData<T>(rand, element)); + smallest = queue.peek(); + } + } + index++; + } + return queue.iterator(); + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/c9cfb17c/flink-java/src/main/java/org/apache/flink/api/java/utils/DataSetUtils.java ---------------------------------------------------------------------- diff --git a/flink-java/src/main/java/org/apache/flink/api/java/utils/DataSetUtils.java b/flink-java/src/main/java/org/apache/flink/api/java/utils/DataSetUtils.java index 142e7cf..d268925 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/utils/DataSetUtils.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/utils/DataSetUtils.java @@ -19,7 +19,14 @@ package org.apache.flink.api.java.utils; import org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.api.java.sampling.IntermediateSampleData; import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.Utils; +import org.apache.flink.api.java.functions.SampleInCoordinator; +import org.apache.flink.api.java.functions.SampleInPartition; +import org.apache.flink.api.java.functions.SampleWithFraction; +import org.apache.flink.api.java.operators.GroupReduceOperator; +import org.apache.flink.api.java.operators.MapPartitionOperator; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.configuration.Configuration; import org.apache.flink.util.Collector; @@ -142,6 +149,94 @@ public class DataSetUtils { }); } + // -------------------------------------------------------------------------------------------- + // Sample + // -------------------------------------------------------------------------------------------- + + /** + * Generate a sample of DataSet by the probability fraction of each element. + * + * @param withReplacement Whether element can be selected more than once. 
+ * @param fraction Probability that each element is chosen, should be [0,1] without replacement, + * and [0, ∞) with replacement. While fraction is larger than 1, the elements are + * expected to be selected multi times into sample on average. + * @return The sampled DataSet + */ + public static <T> MapPartitionOperator<T, T> sample( + DataSet <T> input, + final boolean withReplacement, + final double fraction) { + + return sample(input, withReplacement, fraction, Utils.RNG.nextLong()); + } + + /** + * Generate a sample of DataSet by the probability fraction of each element. + * + * @param withReplacement Whether element can be selected more than once. + * @param fraction Probability that each element is chosen, should be [0,1] without replacement, + * and [0, ∞) with replacement. While fraction is larger than 1, the elements are + * expected to be selected multi times into sample on average. + * @param seed random number generator seed. + * @return The sampled DataSet + */ + public static <T> MapPartitionOperator<T, T> sample( + DataSet <T> input, + final boolean withReplacement, + final double fraction, + final long seed) { + + return input.mapPartition(new SampleWithFraction<T>(withReplacement, fraction, seed)); + } + + /** + * Generate a sample of DataSet which contains fixed size elements. + * <p> + * <strong>NOTE:</strong> Sample with fixed size is not as efficient as sample with fraction, use sample with + * fraction unless you need exact precision. + * <p/> + * + * @param withReplacement Whether element can be selected more than once. + * @param numSample The expected sample size. + * @return The sampled DataSet + */ + public static <T> DataSet<T> sampleWithSize( + DataSet <T> input, + final boolean withReplacement, + final int numSample) { + + return sampleWithSize(input, withReplacement, numSample, Utils.RNG.nextLong()); + } + + /** + * Generate a sample of DataSet which contains fixed size elements.
+ * <p> + * <strong>NOTE:</strong> Sample with fixed size is not as efficient as sample with fraction, use sample with + * fraction unless you need exact precision. + * <p/> + * + * @param withReplacement Whether element can be selected more than once. + * @param numSample The expected sample size. + * @param seed Random number generator seed. + * @return The sampled DataSet + */ + public static <T> DataSet<T> sampleWithSize( + DataSet <T> input, + final boolean withReplacement, + final int numSample, + final long seed) { + + SampleInPartition sampleInPartition = new SampleInPartition<T>(withReplacement, numSample, seed); + MapPartitionOperator mapPartitionOperator = input.mapPartition(sampleInPartition); + + // There is no previous group, so the parallelism of GroupReduceOperator is always 1. + String callLocation = Utils.getCallLocationName(); + SampleInCoordinator<T> sampleInCoordinator = new SampleInCoordinator<T>(withReplacement, numSample, seed); + return new GroupReduceOperator<IntermediateSampleData<T>, T>(mapPartitionOperator, + input.getType(), sampleInCoordinator, callLocation); + } + + // ************************************************************************* // UTIL METHODS // ************************************************************************* http://git-wip-us.apache.org/repos/asf/flink/blob/c9cfb17c/flink-java/src/test/java/org/apache/flink/api/java/sampling/RandomSamplerTest.java ---------------------------------------------------------------------- diff --git a/flink-java/src/test/java/org/apache/flink/api/java/sampling/RandomSamplerTest.java b/flink-java/src/test/java/org/apache/flink/api/java/sampling/RandomSamplerTest.java new file mode 100644 index 0000000..83e5b41 --- /dev/null +++ b/flink-java/src/test/java/org/apache/flink/api/java/sampling/RandomSamplerTest.java @@ -0,0 +1,452 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.api.java.sampling; + +import com.google.common.base.Preconditions; +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; +import org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Set; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** + * This test suite try to verify whether all the random samplers work as we expected, which mainly focus on: + * <ul> + * <li>Does sampled result fit into input parameters? we check parameters like sample fraction, sample size, + * w/o replacement, and so on.</li> + * <li>Does sampled result randomly selected? we verify this by measure how much does it distributed on source data. + * Run Kolmogorov-Smirnov (KS) test between the random samplers and default reference samplers which is distributed + * well-proportioned on source data. If random sampler select elements randomly from source, it would distributed + * well-proportioned on source data as well. 
The KS test will fail to strongly reject the null hypothesis that + * the distributions of sampling gaps are the same. + * </li> + * </ul> + * + * @see <a href="https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Smirnov_test">Kolmogorov Smirnov test</a> + */ +public class RandomSamplerTest { + private final static int SOURCE_SIZE = 10000; + private static KolmogorovSmirnovTest ksTest; + private static List<Double> source; + private final static int DEFFAULT_PARTITION_NUMBER=10; + private List<Double>[] sourcePartitions = new List[DEFFAULT_PARTITION_NUMBER]; + + @BeforeClass + public static void init() { + // initiate source data set. + source = new ArrayList<Double>(SOURCE_SIZE); + for (int i = 0; i < SOURCE_SIZE; i++) { + source.add((double) i); + } + + ksTest = new KolmogorovSmirnovTest(); + } + + private void initSourcePartition() { + for (int i=0; i<DEFFAULT_PARTITION_NUMBER; i++) { + sourcePartitions[i] = new LinkedList<Double>(); + } + for (int i = 0; i< SOURCE_SIZE; i++) { + int index = i % DEFFAULT_PARTITION_NUMBER; + sourcePartitions[index].add((double)i); + } + } + + @Test(expected = java.lang.IllegalArgumentException.class) + public void testBernoulliSamplerWithUnexpectedFraction1() { + verifySamplerFraction(-1, false); + } + + @Test(expected = java.lang.IllegalArgumentException.class) + public void testBernoulliSamplerWithUnexpectedFraction2() { + verifySamplerFraction(2, false); + } + + @Test + public void testBernoulliSamplerFraction() { + verifySamplerFraction(0.01, false); + verifySamplerFraction(0.05, false); + verifySamplerFraction(0.1, false); + verifySamplerFraction(0.3, false); + verifySamplerFraction(0.5, false); + verifySamplerFraction(0.854, false); + verifySamplerFraction(0.99, false); + } + + @Test + public void testBernoulliSamplerDuplicateElements() { + verifyRandomSamplerDuplicateElements(new BernoulliSampler<Double>(0.01)); + verifyRandomSamplerDuplicateElements(new BernoulliSampler<Double>(0.1)); + 
verifyRandomSamplerDuplicateElements(new BernoulliSampler<Double>(0.5)); + } + + @Test(expected = java.lang.IllegalArgumentException.class) + public void testPoissonSamplerWithUnexpectedFraction1() { + verifySamplerFraction(-1, true); + } + + @Test + public void testPoissonSamplerFraction() { + verifySamplerFraction(0.01, true); + verifySamplerFraction(0.05, true); + verifySamplerFraction(0.1, true); + verifySamplerFraction(0.5, true); + verifySamplerFraction(0.854, true); + verifySamplerFraction(0.99, true); + verifySamplerFraction(1.5, true); + } + + @Test(expected = java.lang.IllegalArgumentException.class) + public void testReservoirSamplerUnexpectedSize1() { + verifySamplerFixedSampleSize(-1, true); + } + + @Test(expected = java.lang.IllegalArgumentException.class) + public void testReservoirSamplerUnexpectedSize2() { + verifySamplerFixedSampleSize(-1, false); + } + + @Test + public void testBernoulliSamplerDistribution() { + verifyBernoulliSampler(0.01d); + verifyBernoulliSampler(0.05d); + verifyBernoulliSampler(0.1d); + verifyBernoulliSampler(0.5d); + } + + @Test + public void testPoissonSamplerDistribution() { + verifyPoissonSampler(0.01d); + verifyPoissonSampler(0.05d); + verifyPoissonSampler(0.1d); + verifyPoissonSampler(0.5d); + } + + @Test + public void testReservoirSamplerSampledSize() { + verifySamplerFixedSampleSize(1, true); + verifySamplerFixedSampleSize(10, true); + verifySamplerFixedSampleSize(100, true); + verifySamplerFixedSampleSize(1234, true); + verifySamplerFixedSampleSize(9999, true); + verifySamplerFixedSampleSize(20000, true); + + verifySamplerFixedSampleSize(1, false); + verifySamplerFixedSampleSize(10, false); + verifySamplerFixedSampleSize(100, false); + verifySamplerFixedSampleSize(1234, false); + verifySamplerFixedSampleSize(9999, false); + } + + @Test + public void testReservoirSamplerSampledSize2() { + RandomSampler<Double> sampler = new ReservoirSamplerWithoutReplacement<Double>(20000); + Iterator<Double> sampled = 
sampler.sample(source.iterator()); + assertTrue("ReservoirSamplerWithoutReplacement sampled output size should not beyond the source size.", getSize(sampled) == SOURCE_SIZE); + } + + @Test + public void testReservoirSamplerDuplicateElements() { + verifyRandomSamplerDuplicateElements(new ReservoirSamplerWithoutReplacement<Double>(100)); + verifyRandomSamplerDuplicateElements(new ReservoirSamplerWithoutReplacement<Double>(1000)); + verifyRandomSamplerDuplicateElements(new ReservoirSamplerWithoutReplacement<Double>(5000)); + } + + @Test + public void testReservoirSamplerWithoutReplacement() { + verifyReservoirSamplerWithoutReplacement(100, false); + verifyReservoirSamplerWithoutReplacement(500, false); + verifyReservoirSamplerWithoutReplacement(1000, false); + verifyReservoirSamplerWithoutReplacement(5000, false); + } + + @Test + public void testReservoirSamplerWithReplacement() { + verifyReservoirSamplerWithReplacement(100, false); + verifyReservoirSamplerWithReplacement(500, false); + verifyReservoirSamplerWithReplacement(1000, false); + verifyReservoirSamplerWithReplacement(5000, false); + } + + @Test + public void testReservoirSamplerWithMultiSourcePartitions1() { + initSourcePartition(); + + verifyReservoirSamplerWithoutReplacement(100, true); + verifyReservoirSamplerWithoutReplacement(500, true); + verifyReservoirSamplerWithoutReplacement(1000, true); + verifyReservoirSamplerWithoutReplacement(5000, true); + } + + @Test + public void testReservoirSamplerWithMultiSourcePartitions2() { + initSourcePartition(); + + verifyReservoirSamplerWithReplacement(100, true); + verifyReservoirSamplerWithReplacement(500, true); + verifyReservoirSamplerWithReplacement(1000, true); + verifyReservoirSamplerWithReplacement(5000, true); + } + + /* + * Sample with fixed size, verify whether the sampled result size equals to input size. 
+ */ + private void verifySamplerFixedSampleSize(int numSample, boolean withReplacement) { + RandomSampler<Double> sampler; + if (withReplacement) { + sampler = new ReservoirSamplerWithReplacement<Double>(numSample); + } else { + sampler = new ReservoirSamplerWithoutReplacement<Double>(numSample); + } + Iterator<Double> sampled = sampler.sample(source.iterator()); + assertEquals(numSample, getSize(sampled)); + } + + /* + * Sample with fraction, and verify whether the sampled result close to input fraction. + */ + private void verifySamplerFraction(double fraction, boolean withReplacement) { + RandomSampler<Double> sampler; + if (withReplacement) { + sampler = new PoissonSampler<Double>(fraction); + } else { + sampler = new BernoulliSampler<Double>(fraction); + } + + // take 5 times sample, and take the average result size for next step comparison. + int totalSampledSize = 0; + double sampleCount = 5; + for (int i = 0; i < sampleCount; i++) { + totalSampledSize += getSize(sampler.sample(source.iterator())); + } + double resultFraction = totalSampledSize / ((double) SOURCE_SIZE * sampleCount); + assertTrue(String.format("expected fraction: %f, result fraction: %f", fraction, resultFraction), Math.abs((resultFraction - fraction) / fraction) < 0.1); + } + + /* + * Test sampler without replacement, and verify that there should not exist any duplicate element in sampled result. 
+ */ + private void verifyRandomSamplerDuplicateElements(final RandomSampler<Double> sampler) { + List<Double> list = Lists.newLinkedList(new Iterable<Double>() { + @Override + public Iterator<Double> iterator() { + return sampler.sample(source.iterator()); + } + }); + Set<Double> set = Sets.newHashSet(list); + assertTrue("There should not have duplicate element for sampler without replacement.", list.size() == set.size()); + } + + private int getSize(Iterator iterator) { + int size = 0; + while (iterator.hasNext()) { + iterator.next(); + size++; + } + return size; + } + + private void verifyBernoulliSampler(double fraction) { + BernoulliSampler<Double> sampler = new BernoulliSampler<Double>(fraction); + verifyRandomSamplerWithFraction(fraction, sampler, true); + verifyRandomSamplerWithFraction(fraction, sampler, false); + } + + private void verifyPoissonSampler(double fraction) { + PoissonSampler<Double> sampler = new PoissonSampler<Double>(fraction); + verifyRandomSamplerWithFraction(fraction, sampler, true); + verifyRandomSamplerWithFraction(fraction, sampler, false); + } + + private void verifyReservoirSamplerWithReplacement(int numSamplers, boolean sampleOnPartitions) { + ReservoirSamplerWithReplacement<Double> sampler = new ReservoirSamplerWithReplacement<Double>(numSamplers); + verifyRandomSamplerWithSampleSize(numSamplers, sampler, true, sampleOnPartitions); + verifyRandomSamplerWithSampleSize(numSamplers, sampler, false, sampleOnPartitions); + } + + private void verifyReservoirSamplerWithoutReplacement(int numSamplers, boolean sampleOnPartitions) { + ReservoirSamplerWithoutReplacement<Double> sampler = new ReservoirSamplerWithoutReplacement<Double>(numSamplers); + verifyRandomSamplerWithSampleSize(numSamplers, sampler, true, sampleOnPartitions); + verifyRandomSamplerWithSampleSize(numSamplers, sampler, false, sampleOnPartitions); + } + + /* + * Verify whether random sampler sample with fraction from source data randomly. 
There are two default sample, one is + * sampled from source data with certain interval, the other is sampled only from the first half region of source data, + * If random sampler select elements randomly from source, it would distributed well-proportioned on source data as well, + * so the K-S Test result would accept the first one, while reject the second one. + */ + private void verifyRandomSamplerWithFraction(double fraction, RandomSampler sampler, boolean withDefaultSampler) { + double[] baseSample; + if (withDefaultSampler) { + baseSample = getDefaultSampler(fraction); + } else { + baseSample = getWrongSampler(fraction); + } + + verifyKSTest(sampler, baseSample, withDefaultSampler); + } + + /* + * Verify whether random sampler sample with fixed size from source data randomly. There are two default sample, one is + * sampled from source data with certain interval, the other is sampled only from the first half region of source data, + * If random sampler select elements randomly from source, it would distributed well-proportioned on source data as well, + * so the K-S Test result would accept the first one, while reject the second one. 
+ */ + private void verifyRandomSamplerWithSampleSize(int sampleSize, RandomSampler sampler, boolean withDefaultSampler, boolean sampleWithPartitions) { + double[] baseSample; + if (withDefaultSampler) { + baseSample = getDefaultSampler(sampleSize); + } else { + baseSample = getWrongSampler(sampleSize); + } + + verifyKSTest(sampler, baseSample, withDefaultSampler, sampleWithPartitions); + } + + private void verifyKSTest(RandomSampler sampler, double[] defaultSampler, boolean expectSuccess) { + verifyKSTest(sampler, defaultSampler, expectSuccess, false); + } + + private void verifyKSTest(RandomSampler sampler, double[] defaultSampler, boolean expectSuccess, boolean sampleOnPartitions) { + double[] sampled = getSampledOutput(sampler, sampleOnPartitions); + double pValue = ksTest.kolmogorovSmirnovStatistic(sampled, defaultSampler); + double dValue = getDValue(sampled.length, defaultSampler.length); + if (expectSuccess) { + assertTrue(String.format("KS test result with p value(%f), d value(%f)", pValue, dValue), pValue <= dValue); + } else { + assertTrue(String.format("KS test result with p value(%f), d value(%f)", pValue, dValue), pValue > dValue); + } + } + + private double[] getSampledOutput(RandomSampler<Double> sampler, boolean sampleOnPartitions) { + Iterator<Double> sampled = null; + if (sampleOnPartitions) { + DistributedRandomSampler<Double> reservoirRandomSampler = (DistributedRandomSampler<Double>)sampler; + List<IntermediateSampleData<Double>> intermediateResult = Lists.newLinkedList(); + for (int i=0; i<DEFFAULT_PARTITION_NUMBER; i++) { + Iterator<IntermediateSampleData<Double>> partialIntermediateResult = reservoirRandomSampler.sampleInPartition(sourcePartitions[i].iterator()); + while (partialIntermediateResult.hasNext()) { + intermediateResult.add(partialIntermediateResult.next()); + } + } + sampled = reservoirRandomSampler.sampleInCoordinator(intermediateResult.iterator()); + } else { + sampled = sampler.sample(source.iterator()); + } + List<Double> 
list = Lists.newArrayList(); + while (sampled.hasNext()) { + list.add(sampled.next()); + } + double[] result = transferFromListToArrayWithOrder(list); + return result; + } + + /* + * Some sample result may not order by the input sequence, we should make it in order to do K-S test. + */ + private double[] transferFromListToArrayWithOrder(List<Double> list) { + Collections.sort(list, new Comparator<Double>() { + @Override + public int compare(Double o1, Double o2) { + return o1 - o2 >= 0 ? 1 : -1; + } + }); + double[] result = new double[list.size()]; + for (int i = 0; i < list.size(); i++) { + result[i] = list.get(i); + } + return result; + } + + private double[] getDefaultSampler(double fraction) { + Preconditions.checkArgument(fraction > 0, "Sample fraction should be positive."); + int size = (int) (SOURCE_SIZE * fraction); + double step = 1 / fraction; + double[] defaultSampler = new double[size]; + for (int i = 0; i < size; i++) { + defaultSampler[i] = Math.round(step * i); + } + + return defaultSampler; + } + + private double[] getDefaultSampler(int fixSize) { + Preconditions.checkArgument(fixSize > 0, "Sample fraction should be positive."); + int size = fixSize; + double step = SOURCE_SIZE / (double) fixSize; + double[] defaultSampler = new double[size]; + for (int i = 0; i < size; i++) { + defaultSampler[i] = Math.round(step * i); + } + + return defaultSampler; + } + + /* + * Build a failed sample distribution which only contains elements in the first half of source data. + */ + private double[] getWrongSampler(double fraction) { + Preconditions.checkArgument(fraction > 0, "Sample size should be positive."); + int size = (int) (SOURCE_SIZE * fraction); + int halfSourceSize = SOURCE_SIZE / 2; + double[] wrongSampler = new double[size]; + for (int i = 0; i < size; i++) { + wrongSampler[i] = (double) i % halfSourceSize; + } + + return wrongSampler; + } + + /* + * Build a failed sample distribution which only contains elements in the first half of source data. 
+ */ + private double[] getWrongSampler(int fixSize) { + Preconditions.checkArgument(fixSize > 0, "Sample size be positive."); + int halfSourceSize = SOURCE_SIZE / 2; + int size = fixSize; + double[] wrongSampler = new double[size]; + for (int i = 0; i < size; i++) { + wrongSampler[i] = (double) i % halfSourceSize; + } + + return wrongSampler; + } + + /* + * Calculate the D value of K-S test for p-value 0.05, m and n are the sample size + */ + private double getDValue(int m, int n) { + Preconditions.checkArgument(m > 0, "input sample size should be positive."); + Preconditions.checkArgument(n > 0, "input sample size should be positive."); + double first = (double) m; + double second = (double) n; + return 1.36 * Math.sqrt((first + second) / (first * second)); + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/c9cfb17c/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSetUtils.scala ---------------------------------------------------------------------- diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSetUtils.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSetUtils.scala index b1a1ab6..2663754 100644 --- a/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSetUtils.scala +++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSetUtils.scala @@ -19,7 +19,7 @@ package org.apache.flink.api.scala import org.apache.flink.api.common.typeinfo.TypeInformation -import org.apache.flink.api.java.{utils => jutils} +import org.apache.flink.api.java.{Utils, utils => jutils} import _root_.scala.language.implicitConversions import _root_.scala.reflect.ClassTag @@ -53,6 +53,44 @@ class DataSetUtils[T](val self: DataSet[T]) extends AnyVal { wrap(jutils.DataSetUtils.zipWithUniqueId(self.javaSet)) .map { t => (t.f0.toLong, t.f1) } } + + // -------------------------------------------------------------------------------------------- + // Sample + // 
-------------------------------------------------------------------------------------------- + /** + * Generate a sample of DataSet by the probability fraction of each element. + * + * @param withReplacement Whether element can be selected more than once. + * @param fraction Probability that each element is chosen, should be [0,1] without + * replacement, and [0, ∞) with replacement. While fraction is larger + * than 1, the elements are expected to be selected multi times into + * sample on average. + * @param seed Random number generator seed. + * @return The sampled DataSet + */ + def sample(withReplacement: Boolean, fraction: Double, seed: Long = Utils.RNG.nextLong()) + (implicit ti: TypeInformation[T], ct: ClassTag[T]): DataSet[T] = { + + wrap(jutils.DataSetUtils.sample(self.javaSet, withReplacement, fraction, seed)) + } + + /** + * Generate a sample of DataSet with fixed sample size. + * <p> + * <strong>NOTE:</strong> Sample with fixed size is not as efficient as sample with fraction, + * use sample with fraction unless you need exact precision. + * <p/> + * + * @param withReplacement Whether element can be selected more than once. + * @param numSample The expected sample size. + * @param seed Random number generator seed. 
+ * @return The sampled DataSet + */ + def sampleWithSize(withReplacement: Boolean, numSample: Int, seed: Long = Utils.RNG.nextLong()) + (implicit ti: TypeInformation[T], ct: ClassTag[T]): DataSet[T] = { + + wrap(jutils.DataSetUtils.sampleWithSize(self.javaSet, withReplacement, numSample, seed)) + } } object DataSetUtils { http://git-wip-us.apache.org/repos/asf/flink/blob/c9cfb17c/flink-test-utils/src/main/java/org/apache/flink/test/util/TestBaseUtils.java ---------------------------------------------------------------------- diff --git a/flink-test-utils/src/main/java/org/apache/flink/test/util/TestBaseUtils.java b/flink-test-utils/src/main/java/org/apache/flink/test/util/TestBaseUtils.java index c28347c..ce02267 100644 --- a/flink-test-utils/src/main/java/org/apache/flink/test/util/TestBaseUtils.java +++ b/flink-test-utils/src/main/java/org/apache/flink/test/util/TestBaseUtils.java @@ -19,11 +19,14 @@ package org.apache.flink.test.util; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + import akka.actor.ActorRef; import akka.dispatch.Futures; import akka.pattern.Patterns; import akka.util.Timeout; +import com.google.common.collect.Lists; import org.apache.commons.io.FileUtils; import org.apache.commons.io.IOUtils; import org.apache.flink.api.java.tuple.Tuple; @@ -451,6 +454,34 @@ public class TestBaseUtils extends TestLogger { assertEquals(extectedStrings[i], resultStrings[i]); } } + + // -------------------------------------------------------------------------------------------- + // Comparison methods for tests using sample + // -------------------------------------------------------------------------------------------- + + /** + * The expected string contains all expected results separate with line break, check whether all elements in result + * are contained in the expected string. + * @param result The test result. + * @param expected The expected string value combination. + * @param <T> The result type. 
+ */ + public static <T> void containsResultAsText(List<T> result, String expected) { + String[] expectedStrings = expected.split("\n"); + List<String> resultStrings = Lists.newLinkedList(); + + for (int i = 0; i < result.size(); i++) { + T val = result.get(i); + String str = (val == null) ? "null" : val.toString(); + resultStrings.add(str); + } + + List<String> expectedStringList = Arrays.asList(expectedStrings); + + for (String element : resultStrings) { + assertTrue(expectedStringList.contains(element)); + } + } // -------------------------------------------------------------------------------------------- // Miscellaneous helper methods http://git-wip-us.apache.org/repos/asf/flink/blob/c9cfb17c/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/SampleITCase.java ---------------------------------------------------------------------- diff --git a/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/SampleITCase.java b/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/SampleITCase.java new file mode 100644 index 0000000..a9c75e5 --- /dev/null +++ b/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/SampleITCase.java @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.test.javaApiOperators; + +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.ExecutionEnvironment; +import org.apache.flink.api.java.operators.FlatMapOperator; +import org.apache.flink.api.java.operators.MapPartitionOperator; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.utils.DataSetUtils; +import org.apache.flink.test.javaApiOperators.util.CollectionDataSets; +import org.apache.flink.test.util.MultipleProgramsTestBase; +import org.apache.flink.util.Collector; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.List; +import java.util.Random; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +@SuppressWarnings("serial") +@RunWith(Parameterized.class) +public class SampleITCase extends MultipleProgramsTestBase { + + private static final Random RNG = new Random(); + + public SampleITCase(TestExecutionMode mode) { + super(mode); + } + + @Before + public void initiate() { + ExecutionEnvironment.getExecutionEnvironment().setParallelism(5); + } + + @Test + public void testSamplerWithFractionWithoutReplacement() throws Exception { + verifySamplerWithFractionWithoutReplacement(0d); + verifySamplerWithFractionWithoutReplacement(0.2d); + verifySamplerWithFractionWithoutReplacement(1.0d); + } + + @Test + public void testSamplerWithFractionWithReplacement() throws Exception { + verifySamplerWithFractionWithReplacement(0d); + verifySamplerWithFractionWithReplacement(0.2d); + verifySamplerWithFractionWithReplacement(1.0d); + verifySamplerWithFractionWithReplacement(2.0d); + } + + @Test + public void testSamplerWithSizeWithoutReplacement() throws Exception { + 
verifySamplerWithFixedSizeWithoutReplacement(0); + verifySamplerWithFixedSizeWithoutReplacement(2); + verifySamplerWithFixedSizeWithoutReplacement(21); + } + + @Test + public void testSamplerWithSizeWithReplacement() throws Exception { + verifySamplerWithFixedSizeWithReplacement(0); + verifySamplerWithFixedSizeWithReplacement(2); + verifySamplerWithFixedSizeWithReplacement(21); + } + + private void verifySamplerWithFractionWithoutReplacement(double fraction) throws Exception { + verifySamplerWithFractionWithoutReplacement(fraction, RNG.nextLong()); + } + + private void verifySamplerWithFractionWithoutReplacement(double fraction, long seed) throws Exception { + verifySamplerWithFraction(false, fraction, seed); + } + + private void verifySamplerWithFractionWithReplacement(double fraction) throws Exception { + verifySamplerWithFractionWithReplacement(fraction, RNG.nextLong()); + } + + private void verifySamplerWithFractionWithReplacement(double fraction, long seed) throws Exception { + verifySamplerWithFraction(true, fraction, seed); + } + + private void verifySamplerWithFraction(boolean withReplacement, double fraction, long seed) throws Exception { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + FlatMapOperator<Tuple3<Integer, Long, String>, String> ds = getSourceDataSet(env); + MapPartitionOperator<String, String> sampled = DataSetUtils.sample(ds, withReplacement, fraction, seed); + List<String> result = sampled.collect(); + containsResultAsText(result, getSourceStrings()); + } + + private void verifySamplerWithFixedSizeWithoutReplacement(int numSamples) throws Exception { + verifySamplerWithFixedSizeWithoutReplacement(numSamples, RNG.nextLong()); + } + + private void verifySamplerWithFixedSizeWithoutReplacement(int numSamples, long seed) throws Exception { + verifySamplerWithFixedSize(false, numSamples, seed); + } + + private void verifySamplerWithFixedSizeWithReplacement(int numSamples) throws Exception { + 
verifySamplerWithFixedSizeWithReplacement(numSamples, RNG.nextLong()); + } + + private void verifySamplerWithFixedSizeWithReplacement(int numSamples, long seed) throws Exception { + verifySamplerWithFixedSize(true, numSamples, seed); + } + + private void verifySamplerWithFixedSize(boolean withReplacement, int numSamples, long seed) throws Exception { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + FlatMapOperator<Tuple3<Integer, Long, String>, String> ds = getSourceDataSet(env); + DataSet<String> sampled = DataSetUtils.sampleWithSize(ds, withReplacement, numSamples, seed); + List<String> result = sampled.collect(); + assertEquals(numSamples, result.size()); + containsResultAsText(result, getSourceStrings()); + } + + private FlatMapOperator<Tuple3<Integer, Long, String>, String> getSourceDataSet(ExecutionEnvironment env) { + return CollectionDataSets.get3TupleDataSet(env).flatMap( + new FlatMapFunction<Tuple3<Integer, Long, String>, String>() { + @Override + public void flatMap(Tuple3<Integer, Long, String> value, Collector<String> out) throws Exception { + out.collect(value.f2); + } + }); + } + + private String getSourceStrings() { + return "Hi\n" + + "Hello\n" + + "Hello world\n" + + "Hello world, how are you?\n" + + "I am fine.\n" + + "Luke Skywalker\n" + + "Comment#1\n" + + "Comment#2\n" + + "Comment#3\n" + + "Comment#4\n" + + "Comment#5\n" + + "Comment#6\n" + + "Comment#7\n" + + "Comment#8\n" + + "Comment#9\n" + + "Comment#10\n" + + "Comment#11\n" + + "Comment#12\n" + + "Comment#13\n" + + "Comment#14\n" + + "Comment#15\n"; + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/c9cfb17c/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/SampleITCase.scala ---------------------------------------------------------------------- diff --git a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/SampleITCase.scala b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/SampleITCase.scala new file 
mode 100644 index 0000000..86b0818 --- /dev/null +++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/SampleITCase.scala @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.api.scala.operators + +import java.util.{List => JavaList, Random} + +import org.apache.flink.api.scala._ +import org.apache.flink.api.scala.util.CollectionDataSets +import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode +import org.apache.flink.test.util.{MultipleProgramsTestBase, TestBaseUtils} +import org.junit.Assert._ +import org.junit.runner.RunWith +import org.junit.runners.Parameterized +import org.junit.{Before, After, Test} + +import org.apache.flink.api.scala.DataSetUtils.utilsToDataSet +import scala.collection.JavaConverters._ + +@RunWith(classOf[Parameterized]) +class SampleITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mode) { + private val RNG: Random = new Random + + private var result: JavaList[String] = null; + + @Before + def initiate { + ExecutionEnvironment.getExecutionEnvironment.setParallelism(5) + } + + @After + def after() = { + TestBaseUtils.containsResultAsText(result, getSourceStrings) + } + + @Test + 
@throws(classOf[Exception]) + def testSamplerWithFractionWithoutReplacement { + verifySamplerWithFractionWithoutReplacement(0d) + verifySamplerWithFractionWithoutReplacement(0.2d) + verifySamplerWithFractionWithoutReplacement(1.0d) + } + + @Test + @throws(classOf[Exception]) + def testSamplerWithFractionWithReplacement { + verifySamplerWithFractionWithReplacement(0d) + verifySamplerWithFractionWithReplacement(0.2d) + verifySamplerWithFractionWithReplacement(1.0d) + verifySamplerWithFractionWithReplacement(2.0d) + } + + @Test + @throws(classOf[Exception]) + def testSamplerWithSizeWithoutReplacement { + verifySamplerWithFixedSizeWithoutReplacement(0) + verifySamplerWithFixedSizeWithoutReplacement(2) + verifySamplerWithFixedSizeWithoutReplacement(21) + } + + @Test + @throws(classOf[Exception]) + def testSamplerWithSizeWithReplacement { + verifySamplerWithFixedSizeWithReplacement(0) + verifySamplerWithFixedSizeWithReplacement(2) + verifySamplerWithFixedSizeWithReplacement(21) + } + + @throws(classOf[Exception]) + private def verifySamplerWithFractionWithoutReplacement(fraction: Double) { + verifySamplerWithFractionWithoutReplacement(fraction, RNG.nextLong) + } + + @throws(classOf[Exception]) + private def verifySamplerWithFractionWithoutReplacement(fraction: Double, seed: Long) { + verifySamplerWithFraction(false, fraction, seed) + } + + @throws(classOf[Exception]) + private def verifySamplerWithFractionWithReplacement(fraction: Double) { + verifySamplerWithFractionWithReplacement(fraction, RNG.nextLong) + } + + @throws(classOf[Exception]) + private def verifySamplerWithFractionWithReplacement(fraction: Double, seed: Long) { + verifySamplerWithFraction(true, fraction, seed) + } + + @throws(classOf[Exception]) + private def verifySamplerWithFraction(withReplacement: Boolean, fraction: Double, seed: Long) { + val ds = getSourceDataSet() + val sampled = ds.sample(withReplacement, fraction, seed) + result = sampled.collect.asJava + } + + @throws(classOf[Exception]) + 
private def verifySamplerWithFixedSizeWithoutReplacement(numSamples: Int) { + verifySamplerWithFixedSizeWithoutReplacement(numSamples, RNG.nextLong) + } + + @throws(classOf[Exception]) + private def verifySamplerWithFixedSizeWithoutReplacement(numSamples: Int, seed: Long) { + verifySamplerWithFixedSize(false, numSamples, seed) + } + + @throws(classOf[Exception]) + private def verifySamplerWithFixedSizeWithReplacement(numSamples: Int) { + verifySamplerWithFixedSizeWithReplacement(numSamples, RNG.nextLong) + } + + @throws(classOf[Exception]) + private def verifySamplerWithFixedSizeWithReplacement(numSamples: Int, seed: Long) { + verifySamplerWithFixedSize(true, numSamples, seed) + } + + @throws(classOf[Exception]) + private def verifySamplerWithFixedSize(withReplacement: Boolean, numSamples: Int, seed: Long) { + val ds = getSourceDataSet() + val sampled = ds.sampleWithSize(withReplacement, numSamples, seed) + result = sampled.collect.asJava + assertEquals(numSamples, result.size) + } + + private def getSourceDataSet(): DataSet[String] = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tupleDataSet = CollectionDataSets.get3TupleDataSet(env) + tupleDataSet.map(x => x._3) + } + + private def getSourceStrings: String = { + return "Hi\n" + + "Hello\n" + + "Hello world\n" + + "Hello world, how are you?\n" + + "I am fine.\n" + + "Luke Skywalker\n" + + "Comment#1\n" + + "Comment#2\n" + + "Comment#3\n" + + "Comment#4\n" + + "Comment#5\n" + + "Comment#6\n" + + "Comment#7\n" + + "Comment#8\n" + + "Comment#9\n" + + "Comment#10\n" + + "Comment#11\n" + + "Comment#12\n" + + "Comment#13\n" + + "Comment#14\n" + + "Comment#15\n" + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/c9cfb17c/pom.xml ---------------------------------------------------------------------- diff --git a/pom.xml b/pom.xml index f215fe4..6af0355 100644 --- a/pom.xml +++ b/pom.xml @@ -224,6 +224,12 @@ under the License. 
<version>3.2.1</version> </dependency> + <dependency> + <groupId>org.apache.commons</groupId> + <artifactId>commons-math3</artifactId> + <version>3.5</version> + </dependency> + <!-- Managed dependency required for HBase in flink-hbase --> <dependency> <groupId>org.javassist</groupId>