This is an automated email from the ASF dual-hosted git repository.
zhangzp pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
new 3a9cabf [FLINK-29593] Add QuantileSummary to help calculate
approximate quantiles
3a9cabf is described below
commit 3a9cabf3910e37a8db4db6b3702d5b521e815a10
Author: JiangXin <[email protected]>
AuthorDate: Mon Oct 31 14:11:27 2022 +0800
[FLINK-29593] Add QuantileSummary to help calculate approximate quantiles
This closes #162.
---
.../flink/ml/common/broadcast/BroadcastUtils.java | 3 -
.../flink/ml/common/util/QuantileSummary.java | 414 +++++++++++++++++++++
.../flink/ml/common/util/QuantileSummaryTest.java | 172 +++++++++
3 files changed, 586 insertions(+), 3 deletions(-)
diff --git
a/flink-ml-core/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
b/flink-ml-core/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
index 7315e4e..b6c7f7c 100644
---
a/flink-ml-core/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
+++
b/flink-ml-core/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
@@ -18,7 +18,6 @@
package org.apache.flink.ml.common.broadcast;
-import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.iteration.compile.DraftExecutionEnvironment;
import
org.apache.flink.ml.common.broadcast.operator.BroadcastVariableReceiverOperatorFactory;
@@ -40,7 +39,6 @@ import java.util.UUID;
import java.util.function.Function;
/** Utility class to support withBroadcast in DataStream. */
-@Internal
public class BroadcastUtils {
/**
* supports withBroadcastStream in DataStream API. Broadcast data streams
are available at all
@@ -63,7 +61,6 @@ public class BroadcastUtils {
* operator in this function, otherwise it raises an exception.
* @return the output data stream.
*/
- @Internal
public static <OUT> DataStream<OUT> withBroadcastStream(
List<DataStream<?>> inputList,
Map<String, DataStream<?>> bcStreams,
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/util/QuantileSummary.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/util/QuantileSummary.java
new file mode 100644
index 0000000..7ae1578
--- /dev/null
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/util/QuantileSummary.java
@@ -0,0 +1,414 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.util;
+
+import org.apache.flink.util.Preconditions;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+/**
+ * Helper class to compute an approximate quantile summary. This
implementation is based on the
+ * algorithm proposed in the paper: "Space-efficient Online Computation of
Quantile Summaries" by
+ * Greenwald, Michael and Khanna, Sanjeev.
(https://doi.org/10.1145/375663.375670)
+ */
+public class QuantileSummary implements Serializable {
+
+ /** The default size of head buffer. */
+ private static final int DEFAULT_HEAD_SIZE = 50000;
+
+ /** The default compression threshold. */
+ private static final int DEFAULT_COMPRESS_THRESHOLD = 10000;
+
+ /** The target relative error. */
+ private final double relativeError;
+
+ /**
+ * The compression threshold. After the internal buffer of statistics
crosses this size, it
+ * attempts to compress the statistics together.
+ */
+ private final int compressThreshold;
+
+ /** The count of all the elements inserted to be calculated. */
+ private final long count;
+
+ /** A buffer of quantile statistics. */
+ private final List<StatsTuple> sampled;
+
+ /** A buffer of the latest samples seen so far. */
+ private final List<Double> headBuffer = new ArrayList<>();
+
+ /** Whether the quantile summary has been compressed. */
+ private boolean compressed;
+
+ /**
+ * QuantileSummary Constructor.
+ *
+ * @param relativeError The target relative error.
+ */
+ public QuantileSummary(double relativeError) {
+ this(relativeError, DEFAULT_COMPRESS_THRESHOLD);
+ }
+
+ /**
+ * QuantileSummary Constructor.
+ *
+ * @param relativeError The target relative error.
+ * @param compressThreshold the compression threshold. After the internal
buffer of statistics
+ * crosses this size, it attempts to compress the statistics together.
+ */
+ @SuppressWarnings("unchecked")
+ public QuantileSummary(double relativeError, int compressThreshold) {
+ this(relativeError, compressThreshold, Collections.EMPTY_LIST, 0,
false);
+ }
+
+ /**
+ * QuantileSummary Constructor.
+ *
+ * @param relativeError The target relative error.
+ * @param compressThreshold the compression threshold.
+ * @param sampled A buffer of quantile statistics. See the G-K article for
more details.
+ * @param count The count of all the elements inserted in the sampled
buffer.
+ * @param compressed Whether the statistics have been compressed.
+ */
+ private QuantileSummary(
+ double relativeError,
+ int compressThreshold,
+ List<StatsTuple> sampled,
+ long count,
+ boolean compressed) {
+ Preconditions.checkArgument(
+ relativeError > 0 && relativeError < 1,
+ "An appropriate relative error must lay between 0 and 1.");
+ Preconditions.checkArgument(
+ compressThreshold > 0, "An compress threshold must greater
than 0.");
+ this.relativeError = relativeError;
+ this.compressThreshold = compressThreshold;
+ this.sampled = sampled;
+ this.count = count;
+ this.compressed = compressed;
+ }
+
+ /**
+ * Insert a new observation into the summary.
+ *
+ * @param item The new observation to insert into the summary.
+ * @return A summary with the given observation inserted into the summary.
+ */
+ public QuantileSummary insert(double item) {
+ headBuffer.add(item);
+ compressed = false;
+ if (headBuffer.size() >= DEFAULT_HEAD_SIZE) {
+ QuantileSummary result = insertHeadBuffer();
+ if (result.sampled.size() >= compressThreshold) {
+ return result.compress();
+ } else {
+ return result;
+ }
+ } else {
+ return this;
+ }
+ }
+
+ /**
+ * Returns a new summary that compresses the summary statistics and the
head buffer.
+ *
+ * <p>This implements the COMPRESS function of the GK algorithm.
+ *
+ * @return The compressed summary.
+ */
+ public QuantileSummary compress() {
+ if (compressed) {
+ return this;
+ }
+ QuantileSummary inserted = insertHeadBuffer();
+ Preconditions.checkState(inserted.headBuffer.isEmpty());
+ Preconditions.checkState(inserted.count == count + headBuffer.size());
+
+ List<StatsTuple> compressed =
+ compressInternal(inserted.sampled, 2 * relativeError *
inserted.count);
+ return new QuantileSummary(
+ relativeError, compressThreshold, compressed, inserted.count,
true);
+ }
+
+ /**
+ * Merges two summaries together.
+ *
+ * @param other The summary to be merged.
+ * @return The merged summary.
+ */
+ public QuantileSummary merge(QuantileSummary other) {
+ Preconditions.checkState(
+ headBuffer.isEmpty(), "Current buffer needs to be compressed
before merge.");
+ Preconditions.checkState(
+ other.headBuffer.isEmpty(), "Other buffer needs to be
compressed before merge.");
+
+ if (other.count == 0) {
+ return shallowCopy();
+ } else if (count == 0) {
+ return other.shallowCopy();
+ } else {
+ List<StatsTuple> mergedSampled = new ArrayList<>();
+ double mergedRelativeError = Math.max(relativeError,
other.relativeError);
+ long mergedCount = count + other.count;
+ long additionalSelfDelta =
+ Double.valueOf(Math.floor(2 * other.relativeError *
other.count)).longValue();
+ long additionalOtherDelta =
+ Double.valueOf(Math.floor(2 * relativeError *
count)).longValue();
+
+ int selfIdx = 0;
+ int otherIdx = 0;
+ while (selfIdx < sampled.size() && otherIdx <
other.sampled.size()) {
+ StatsTuple selfSample = sampled.get(selfIdx);
+ StatsTuple otherSample = other.sampled.get(otherIdx);
+ StatsTuple nextSample;
+ long additionalDelta = 0;
+ if (selfSample.value < otherSample.value) {
+ nextSample = selfSample;
+ if (otherIdx > 0) {
+ additionalDelta = additionalSelfDelta;
+ }
+ selfIdx++;
+ } else {
+ nextSample = otherSample;
+ if (selfIdx > 0) {
+ additionalDelta = additionalOtherDelta;
+ }
+ otherIdx++;
+ }
+ nextSample = nextSample.shallowCopy();
+ nextSample.delta = nextSample.delta + additionalDelta;
+ mergedSampled.add(nextSample);
+ }
+ IntStream.range(selfIdx, sampled.size())
+ .forEach(i -> mergedSampled.add(sampled.get(i)));
+ IntStream.range(otherIdx, other.sampled.size())
+ .forEach(i -> mergedSampled.add(other.sampled.get(i)));
+
+ List<StatsTuple> comp =
+ compressInternal(mergedSampled, 2 * mergedRelativeError *
mergedCount);
+ return new QuantileSummary(
+ mergedRelativeError, compressThreshold, comp, mergedCount,
true);
+ }
+ }
+
+ /**
+ * Runs a query for a given percentile. The query can only be run on a
compressed summary, you
+ * need to call compress() before using it.
+ *
+ * @param percentile The target percentile.
+ * @return The corresponding approximate quantile.
+ */
+ public double query(double percentile) {
+ return query(new double[] {percentile})[0];
+ }
+
+ /**
+ * Runs a query for a given sequence of percentiles. The query can only be
run on a compressed
+ * summary, you need to call compress() before using it.
+ *
+ * @param percentiles A list of the target percentiles.
+ * @return A list of the corresponding approximate quantiles, in the same
order as the input.
+ */
+ public double[] query(double[] percentiles) {
+ Arrays.stream(percentiles)
+ .forEach(
+ x ->
+ Preconditions.checkState(
+ x >= 0 && x <= 1.0,
+ "percentile should be in the range
[0.0, 1.0]."));
+ Preconditions.checkState(
+ headBuffer.isEmpty(),
+ "Cannot operate on an uncompressed summary, call compress()
first.");
+ Preconditions.checkState(
+ sampled != null && !sampled.isEmpty(),
+ "Cannot query percentiles without any records inserted.");
+ double targetError = Long.MIN_VALUE;
+ for (StatsTuple tuple : sampled) {
+ targetError = Math.max(targetError, (tuple.delta + tuple.g));
+ }
+ targetError = targetError / 2;
+ Map<Double, Integer> zipWithIndex = new HashMap<>(percentiles.length);
+ IntStream.range(0, percentiles.length).forEach(i ->
zipWithIndex.put(percentiles[i], i));
+
+ int index = 0;
+ long minRank = sampled.get(0).g;
+ double[] sorted = Arrays.stream(percentiles).sorted().toArray();
+ double[] result = new double[percentiles.length];
+
+ for (double item : sorted) {
+ int percentileIndex = zipWithIndex.get(item);
+ if (item <= relativeError) {
+ result[percentileIndex] = sampled.get(0).value;
+ } else if (item >= 1 - relativeError) {
+ result[percentileIndex] = sampled.get(sampled.size() -
1).value;
+ } else {
+ QueryResult queryResult =
+ findApproximateQuantile(index, minRank, targetError,
item);
+ index = queryResult.index;
+ minRank = queryResult.minRankAtIndex;
+ result[percentileIndex] = queryResult.percentile;
+ }
+ }
+ return result;
+ }
+
+ /**
+ * Checks whether the QuantileSummary has inserted rows. Running query on
an empty
+ * QuantileSummary would cause {@link java.lang.IllegalStateException}.
+ *
+ * @return True if the QuantileSummary is empty, otherwise false.
+ */
+ public boolean isEmpty() {
+ return headBuffer.isEmpty() && sampled.isEmpty();
+ }
+
+ private QuantileSummary insertHeadBuffer() {
+ if (headBuffer.isEmpty()) {
+ return this;
+ }
+
+ long newCount = count;
+ List<StatsTuple> newSamples = new ArrayList<>();
+ List<Double> sorted =
headBuffer.stream().sorted().collect(Collectors.toList());
+
+ int cursor = 0;
+ for (int i = 0; i < sorted.size(); i++) {
+ while (cursor < sampled.size() && sampled.get(cursor).value <=
sorted.get(i)) {
+ newSamples.add(sampled.get(cursor));
+ cursor++;
+ }
+
+ long delta = Double.valueOf(Math.floor(2.0 * relativeError *
count)).longValue();
+ if (newSamples.isEmpty() || (cursor == sampled.size() && i ==
sorted.size() - 1)) {
+ delta = 0;
+ }
+ StatsTuple tuple = new StatsTuple(sorted.get(i), 1L, delta);
+ newSamples.add(tuple);
+ newCount++;
+ }
+
+ for (int i = cursor; i < sampled.size(); i++) {
+ newSamples.add(sampled.get(i));
+ }
+ return new QuantileSummary(relativeError, compressThreshold,
newSamples, newCount, false);
+ }
+
+ private List<StatsTuple> compressInternal(
+ List<StatsTuple> currentSamples, double mergeThreshold) {
+ if (currentSamples.isEmpty()) {
+ return Collections.emptyList();
+ }
+ LinkedList<StatsTuple> result = new LinkedList<>();
+
+ StatsTuple head = currentSamples.get(currentSamples.size() - 1);
+ for (int i = currentSamples.size() - 2; i >= 1; i--) {
+ StatsTuple tuple = currentSamples.get(i);
+ if (tuple.g + head.g + head.delta < mergeThreshold) {
+ head = head.shallowCopy();
+ head.g = head.g + tuple.g;
+ } else {
+ result.addFirst(head);
+ head = tuple;
+ }
+ }
+ result.addFirst(head);
+
+ StatsTuple currHead = currentSamples.get(0);
+ if (currHead.value <= head.value && currentSamples.size() > 1) {
+ result.addFirst(currHead);
+ }
+ return new ArrayList<>(result);
+ }
+
+ private QueryResult findApproximateQuantile(
+ int index, long minRankAtIndex, double targetError, double
percentile) {
+ StatsTuple curSample = sampled.get(index);
+ long rank = Double.valueOf(Math.ceil(percentile * count)).longValue();
+ long minRank = minRankAtIndex;
+
+ for (int i = index; i < sampled.size() - 1; ) {
+ long maxRank = minRank + curSample.delta;
+ if (maxRank - targetError < rank && rank <= minRank + targetError)
{
+ return new QueryResult(i, minRank, curSample.value);
+ } else {
+ curSample = sampled.get(++i);
+ minRank += curSample.g;
+ }
+ }
+ return new QueryResult(sampled.size() - 1, 0,
sampled.get(sampled.size() - 1).value);
+ }
+
+ public double getRelativeError() {
+ return relativeError;
+ }
+
+ private QuantileSummary shallowCopy() {
+ return new QuantileSummary(relativeError, compressThreshold, sampled,
count, compressed);
+ }
+
+ /** Wrapper class to hold all information returned after querying. */
+ private static class QueryResult {
+ private final int index;
+ private final long minRankAtIndex;
+ private final double percentile;
+
+ public QueryResult(int index, long minRankAtIndex, double percentile) {
+ this.index = index;
+ this.minRankAtIndex = minRankAtIndex;
+ this.percentile = percentile;
+ }
+ }
+
+ /**
+ * Wrapper class to hold all statistics from the Greenwald-Khanna paper.
It contains the
+ * following information:
+ *
+ * <ul>
+ * <li>value: the sampled value.
+ * <li>g: the difference between the least rank of this element and the
rank of the preceding
+ * element.
+ * <li>delta: the maximum span of the rank.
+ * </ul>
+ */
+ private static class StatsTuple implements Serializable {
+ private static final long serialVersionUID = 1L;
+ private final double value;
+ private long g;
+ private long delta;
+
+ public StatsTuple(double value, long g, long delta) {
+ this.value = value;
+ this.g = g;
+ this.delta = delta;
+ }
+
+ public StatsTuple shallowCopy() {
+ return new StatsTuple(value, g, delta);
+ }
+ }
+}
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/util/QuantileSummaryTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/util/QuantileSummaryTest.java
new file mode 100644
index 0000000..b972b5c
--- /dev/null
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/util/QuantileSummaryTest.java
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.util;
+
+import org.apache.commons.lang3.ArrayUtils;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.stream.IntStream;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+/** Tests {@link QuantileSummary}. */
+public class QuantileSummaryTest {
+
+ private List<double[]> datasets;
+
+ @Before
+ public void prepare() {
+ double[] increasing = IntStream.range(0, 100).mapToDouble(x ->
x).toArray();
+ double[] decreasing = IntStream.range(0, 100).mapToDouble(x -> 99 -
x).toArray();
+ double[] negatives = IntStream.range(-100, 0).mapToDouble(x ->
x).toArray();
+
+ datasets = new ArrayList<>(Arrays.asList(increasing, decreasing,
negatives));
+ }
+
+ private QuantileSummary buildSummary(double[] data, double epsilon) {
+ QuantileSummary summary = new QuantileSummary(epsilon);
+ for (double datum : data) {
+ summary = summary.insert(datum);
+ }
+ return summary.compress();
+ }
+
+ private void checkQuantiles(double[] data, double[] percentiles,
QuantileSummary summary) {
+ if (data.length == 0) {
+ assertNull(summary.query(percentiles));
+ } else {
+ double[] quantiles = summary.query(percentiles);
+ IntStream.range(0, percentiles.length)
+ .forEach(
+ i ->
+ validateApproximation(
+ quantiles[i], data,
percentiles[i], summary));
+ }
+ }
+
+ private void validateApproximation(
+ double approx, double[] data, double percentile, QuantileSummary
summary) {
+ double rank =
+ Math.ceil(
+ (Arrays.stream(data).filter(x -> x <= approx).count()
+ + Arrays.stream(data).filter(x -> x <
approx).count())
+ / 2.0);
+ double lower = Math.floor((percentile - summary.getRelativeError()) *
data.length);
+ double upper = Math.ceil((percentile + summary.getRelativeError()) *
data.length);
+ String errMessage =
+ String.format(
+ "Rank not in [%s, %s], percentile: %s, approx
returned: %s",
+ lower, upper, percentile, approx);
+ assertTrue(errMessage, rank >= lower && rank <= upper);
+ }
+
+ private void checkMergedQuantiles(
+ double[] data1,
+ double epsilon1,
+ double[] data2,
+ double epsilon2,
+ double[] percentiles) {
+ QuantileSummary summary1 = buildSummary(data1, epsilon1);
+ QuantileSummary summary2 = buildSummary(data2, epsilon2);
+ QuantileSummary newSummary = summary2.merge(summary1);
+
+ double[] quantiles = newSummary.query(percentiles);
+ IntStream.range(0, percentiles.length)
+ .forEach(
+ i ->
+ validateApproximation(
+ quantiles[i],
+ ArrayUtils.addAll(data1, data2),
+ percentiles[i],
+ newSummary));
+ }
+
+ @Test
+ public void testQuantiles() {
+ for (double[] data : datasets) {
+ QuantileSummary summary = buildSummary(data, 0.001);
+ double[] percentiles = {0, 0.01, 0.1, 0.25, 0.75, 0.5, 0.9, 0.99,
1};
+ checkQuantiles(data, percentiles, summary);
+ }
+ }
+
+ @Test
+ public void testOnEmptyDataset() {
+ double[] data = new double[0];
+ QuantileSummary summary = buildSummary(data, 0.001);
+ double[] percentiles = {0, 0.01, 0.1, 0.25, 0.75, 0.5, 0.9, 0.99, 1};
+ try {
+ checkQuantiles(data, percentiles, summary);
+ fail();
+ } catch (Throwable e) {
+ assertEquals("Cannot query percentiles without any records
inserted.", e.getMessage());
+ }
+ }
+
+ @Test
+ public void testMerge() {
+ double[] data1 = IntStream.range(0, 100).mapToDouble(x -> x).toArray();
+ double[] data2 = IntStream.range(100, 200).mapToDouble(x ->
x).toArray();
+ double[] data3 = IntStream.range(0, 1000).mapToDouble(x ->
x).toArray();
+ double[] data4 = IntStream.range(-50, 50).mapToDouble(x ->
x).toArray();
+
+ double[] percentiles = {0, 0.01, 0.1, 0.25, 0.75, 0.5, 0.9, 0.99, 1};
+ checkMergedQuantiles(data1, 0.001, data2, 0.001, percentiles);
+ checkMergedQuantiles(data1, 0.0001, data2, 0.0001, percentiles);
+ checkMergedQuantiles(data1, 0.001, data3, 0.001, percentiles);
+ checkMergedQuantiles(data1, 0.001, data4, 0.001, percentiles);
+ }
+
+ @Test
+ public void testQuerySinglePercentile() {
+ QuantileSummary summary = buildSummary(datasets.get(0), 0.001);
+ double approx = summary.query(0.25);
+ validateApproximation(approx, datasets.get(0), 0.25, summary);
+ }
+
+ @Test
+ public void testCompressMultiTimes() {
+ QuantileSummary summary = buildSummary(datasets.get(0), 0.001);
+ QuantileSummary newSummary = summary.compress();
+ assertEquals(summary, newSummary);
+ }
+
+ @Test
+ public void testIsEmpty() {
+ QuantileSummary summary = new QuantileSummary(0.01);
+ assertTrue(summary.isEmpty());
+
+ summary = summary.insert(1);
+ assertFalse(summary.isEmpty());
+
+ summary = summary.compress();
+ assertFalse(summary.isEmpty());
+
+ summary = summary.merge(new QuantileSummary(0.01));
+ assertFalse(summary.isEmpty());
+ }
+}