Repository: incubator-beam Updated Branches: refs/heads/master 96e286fec -> cab0c57c0
Use mutable long[] accumulator for Count CombineFn Previously the accumulator was stored as a Long. This uses a singleton long[] to avoid the boxing and unboxing on every increment. This required changing the Coder (the format actually remains the same, but we have no way of declaring that), so it is not backwards compatible with reload. Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/1792965b Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/1792965b Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/1792965b Branch: refs/heads/master Commit: 1792965b865472d6f5199388a1700da690a74daa Parents: 17863c8 Author: Robert Bradshaw <[email protected]> Authored: Tue Mar 29 12:28:13 2016 -0700 Committer: Kenneth Knowles <[email protected]> Committed: Thu Mar 31 13:23:40 2016 -0700 ---------------------------------------------------------------------- .../cloud/dataflow/sdk/transforms/Count.java | 81 +++++++++++++++++--- 1 file changed, 69 insertions(+), 12 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1792965b/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/transforms/Count.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/transforms/Count.java b/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/transforms/Count.java index ffa11d1..5ce4d2e 100644 --- a/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/transforms/Count.java +++ b/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/transforms/Count.java @@ -16,10 +16,23 @@ package com.google.cloud.dataflow.sdk.transforms; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import 
com.google.cloud.dataflow.sdk.coders.CoderRegistry; +import com.google.cloud.dataflow.sdk.coders.CustomCoder; import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; +import com.google.cloud.dataflow.sdk.util.VarInt; import com.google.cloud.dataflow.sdk.values.KV; import com.google.cloud.dataflow.sdk.values.PCollection; +import java.io.EOFException; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.UTFDataFormatException; +import java.util.Iterator; + + /** * {@code PTransorm}s to count the elements in a {@link PCollection}. * @@ -106,30 +119,74 @@ public class Count { /** * A {@link CombineFn} that counts elements. */ - private static class CountFn<T> extends CombineFn<T, Long, Long> { + private static class CountFn<T> extends CombineFn<T, long[], Long> { + // Note that the long[] accumulator always has size 1, used as + // a box for a mutable long. @Override - public Long createAccumulator() { - return 0L; + public long[] createAccumulator() { + return new long[] {0}; } @Override - public Long addInput(Long accumulator, T input) { - return accumulator + 1; + public long[] addInput(long[] accumulator, T input) { + accumulator[0] += 1; + return accumulator; } @Override - public Long mergeAccumulators(Iterable<Long> accumulators) { - long result = 0L; - for (Long accum : accumulators) { - result += accum; + public long[] mergeAccumulators(Iterable<long[]> accumulators) { + Iterator<long[]> iter = accumulators.iterator(); + if (!iter.hasNext()) { + return createAccumulator(); } - return result; + long[] running = iter.next(); + while (iter.hasNext()) { + running[0] += iter.next()[0]; + } + return running; } @Override - public Long extractOutput(Long accumulator) { - return accumulator; + public Long extractOutput(long[] accumulator) { + return accumulator[0]; + } + + @Override + public Coder<long[]> getAccumulatorCoder(CoderRegistry registry, + Coder<T> inputCoder) { + return new CustomCoder<long[]>() { 
+ @Override + public void encode(long[] value, OutputStream outStream, Context context) + throws IOException { + VarInt.encode(value[0], outStream); + } + + @Override + public long[] decode(InputStream inStream, Context context) + throws IOException, CoderException { + try { + return new long[] {VarInt.decodeLong(inStream)}; + } catch (EOFException | UTFDataFormatException exn) { + throw new CoderException(exn); + } + } + + @Override + public boolean isRegisterByteSizeObserverCheap(long[] value, Context context) { + return true; + } + + @Override + protected long getEncodedElementByteSize(long[] value, Context context) { + return VarInt.getLength(value[0]); + } + + @Override + public String getEncodingId() { + return "VarLongSingletonArray"; + } + }; } } }
