Add GroupIntoBatches. This groups input KV<K, V>s into output KV<K, Iterable<V>>s of a specified batch size.
Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/1e9089ff Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/1e9089ff Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/1e9089ff Branch: refs/heads/master Commit: 1e9089ffdf969792d2fae3ecca829a8a4f3f3884 Parents: 498ce9f Author: Etienne Chauchot <[email protected]> Authored: Wed Jan 18 15:04:48 2017 +0100 Committer: Thomas Groh <[email protected]> Committed: Tue Apr 4 09:29:56 2017 -0700 ---------------------------------------------------------------------- .../beam/sdk/transforms/GroupIntoBatches.java | 229 ++++++++++++++++++ .../sdk/transforms/GroupIntoBatchesTest.java | 232 +++++++++++++++++++ 2 files changed, 461 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/1e9089ff/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupIntoBatches.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupIntoBatches.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupIntoBatches.java new file mode 100644 index 0000000..095ca2a --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupIntoBatches.java @@ -0,0 +1,229 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.transforms; + +import static com.google.common.base.Preconditions.checkArgument; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.Iterables; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.VarLongCoder; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.util.TimeDomain; +import org.apache.beam.sdk.util.Timer; +import org.apache.beam.sdk.util.TimerSpec; +import org.apache.beam.sdk.util.TimerSpecs; +import org.apache.beam.sdk.util.state.AccumulatorCombiningState; +import org.apache.beam.sdk.util.state.BagState; +import org.apache.beam.sdk.util.state.StateSpec; +import org.apache.beam.sdk.util.state.StateSpecs; +import org.apache.beam.sdk.util.state.ValueState; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A {@link PTransform} that batches inputs to a desired batch size. Batches will contain only + * elements of a single key. + * + * <p>Elements are buffered until there are {@code batchSize} elements + * buffered, at which point they are output to the output {@link PCollection}. + * + * <p>Windows are preserved (batches contain elements from the same window). 
+ * Batches may contain elements from more than one bundle + * + * <p>Example (batch call a webservice and get return codes) + * + * <pre>{@code + * Pipeline pipeline = Pipeline.create(...); + * ... // KV collection + * long batchSize = 100L; + * pipeline.apply(GroupIntoBatches.<String, String>ofSize(batchSize)) + * .setCoder(KvCoder.of(StringUtf8Coder.of(), IterableCoder.of(StringUtf8Coder.of()))) + * .apply(ParDo.of(new DoFn<KV<String, Iterable<String>>, KV<String, String>>() { + * {@literal @}ProcessElement + * public void processElement(ProcessContext c){ + * c.output(KV.of(c.element().getKey(), callWebService(c.element().getValue()))); + * } + * })); + * pipeline.run(); + * }</pre> + */ +public class GroupIntoBatches<K, InputT> + extends PTransform<PCollection<KV<K, InputT>>, PCollection<KV<K, Iterable<InputT>>>> { + + private final long batchSize; + + private GroupIntoBatches(long batchSize) { + this.batchSize = batchSize; + } + + public static <K, InputT> GroupIntoBatches<K, InputT> ofSize(long batchSize) { + return new GroupIntoBatches<>(batchSize); + } + + @Override + public PCollection<KV<K, Iterable<InputT>>> expand(PCollection<KV<K, InputT>> input) { + Duration allowedLateness = input.getWindowingStrategy().getAllowedLateness(); + + checkArgument( + input.getCoder() instanceof KvCoder, + "coder specified in the input PCollection is not a KvCoder"); + KvCoder inputCoder = (KvCoder) input.getCoder(); + Coder<K> keyCoder = (Coder<K>) inputCoder.getCoderArguments().get(0); + Coder<InputT> valueCoder = (Coder<InputT>) inputCoder.getCoderArguments().get(1); + + return input.apply( + ParDo.of(new GroupIntoBatchesDoFn<>(batchSize, allowedLateness, keyCoder, valueCoder))); + } + + @VisibleForTesting + static class GroupIntoBatchesDoFn<K, InputT> + extends DoFn<KV<K, InputT>, KV<K, Iterable<InputT>>> { + + private static final Logger LOGGER = LoggerFactory.getLogger(GroupIntoBatchesDoFn.class); + private static final String END_OF_WINDOW_ID = "endOFWindow"; + 
private static final String BATCH_ID = "batch"; + private static final String NUM_ELEMENTS_IN_BATCH_ID = "numElementsInBatch"; + private static final String KEY_ID = "key"; + private final long batchSize; + private final Duration allowedLateness; + + @TimerId(END_OF_WINDOW_ID) + private final TimerSpec timer = TimerSpecs.timer(TimeDomain.EVENT_TIME); + + @StateId(BATCH_ID) + private final StateSpec<Object, BagState<InputT>> batchSpec; + + @StateId(NUM_ELEMENTS_IN_BATCH_ID) + private final StateSpec<Object, AccumulatorCombiningState<Long, Long, Long>> + numElementsInBatchSpec; + + @StateId(KEY_ID) + private final StateSpec<Object, ValueState<K>> keySpec; + + private final long prefetchFrequency; + + GroupIntoBatchesDoFn( + long batchSize, + Duration allowedLateness, + Coder<K> inputKeyCoder, + Coder<InputT> inputValueCoder) { + this.batchSize = batchSize; + this.allowedLateness = allowedLateness; + this.batchSpec = StateSpecs.bag(inputValueCoder); + this.numElementsInBatchSpec = + StateSpecs.combiningValue( + VarLongCoder.of(), + new Combine.CombineFn<Long, Long, Long>() { + + @Override + public Long createAccumulator() { + return 0L; + } + + @Override + public Long addInput(Long accumulator, Long input) { + return accumulator + input; + } + + @Override + public Long mergeAccumulators(Iterable<Long> accumulators) { + long sum = 0L; + for (Long accumulator : accumulators) { + sum += accumulator; + } + return sum; + } + + @Override + public Long extractOutput(Long accumulator) { + return accumulator; + } + }); + + this.keySpec = StateSpecs.value(inputKeyCoder); + // prefetch every 20% of batchSize elements. Do not prefetch if batchSize is too little + this.prefetchFrequency = ((batchSize / 5) <= 1) ? 
Long.MAX_VALUE : (batchSize / 5); + } + + @ProcessElement + public void processElement( + @TimerId(END_OF_WINDOW_ID) Timer timer, + @StateId(BATCH_ID) BagState<InputT> batch, + @StateId(NUM_ELEMENTS_IN_BATCH_ID) + AccumulatorCombiningState<Long, Long, Long> numElementsInBatch, + @StateId(KEY_ID) ValueState<K> key, + ProcessContext c, + BoundedWindow window) { + Instant windowExpires = window.maxTimestamp().plus(allowedLateness); + + LOGGER.debug( + "*** SET TIMER *** to point in time {} for window {}", + windowExpires.toString(), window.toString()); + timer.set(windowExpires); + key.write(c.element().getKey()); + batch.add(c.element().getValue()); + LOGGER.debug("*** BATCH *** Add element for window {} ", window.toString()); + // blind add is supported with combiningState + numElementsInBatch.add(1L); + Long num = numElementsInBatch.read(); + if (num % prefetchFrequency == 0) { + //prefetch data and modify batch state (readLater() modifies this) + batch.readLater(); + } + if (num >= batchSize) { + LOGGER.debug("*** END OF BATCH *** for window {}", window.toString()); + flushBatch(c, key, batch, numElementsInBatch); + } + } + + @OnTimer(END_OF_WINDOW_ID) + public void onTimerCallback( + OnTimerContext context, + @StateId(KEY_ID) ValueState<K> key, + @StateId(BATCH_ID) BagState<InputT> batch, + @StateId(NUM_ELEMENTS_IN_BATCH_ID) + AccumulatorCombiningState<Long, Long, Long> numElementsInBatch, + BoundedWindow window) { + LOGGER.debug( + "*** END OF WINDOW *** for timer timestamp {} in windows {}", + context.timestamp(), window.toString()); + flushBatch(context, key, batch, numElementsInBatch); + } + + private void flushBatch( + Context c, + ValueState<K> key, + BagState<InputT> batch, + AccumulatorCombiningState<Long, Long, Long> numElementsInBatch) { + Iterable<InputT> values = batch.read(); + // when the timer fires, batch state might be empty + if (Iterables.size(values) > 0) { + c.output(KV.of(key.read(), values)); + } + batch.clear(); + LOGGER.debug("*** BATCH 
*** clear"); + numElementsInBatch.clear(); + } + } +} http://git-wip-us.apache.org/repos/asf/beam/blob/1e9089ff/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupIntoBatchesTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupIntoBatchesTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupIntoBatchesTest.java new file mode 100644 index 0000000..54e2d5a --- /dev/null +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupIntoBatchesTest.java @@ -0,0 +1,232 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.transforms; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import com.google.common.collect.Iterables; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Iterator; +import org.apache.beam.sdk.coders.IterableCoder; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.testing.NeedsRunner; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.testing.TestStream; +import org.apache.beam.sdk.testing.UsesStatefulParDo; +import org.apache.beam.sdk.testing.UsesTestStream; +import org.apache.beam.sdk.testing.UsesTimersInParDo; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.FixedWindows; +import org.apache.beam.sdk.transforms.windowing.Window; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.TimestampedValue; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Test Class for {@link GroupIntoBatches}. 
*/ +@RunWith(JUnit4.class) +public class GroupIntoBatchesTest implements Serializable { + private static final int BATCH_SIZE = 5; + private static final long NUM_ELEMENTS = 10; + private static final int ALLOWED_LATENESS = 0; + private static final Logger LOGGER = LoggerFactory.getLogger(GroupIntoBatchesTest.class); + @Rule public transient TestPipeline pipeline = TestPipeline.create(); + private transient ArrayList<KV<String, String>> data = createTestData(); + + private static ArrayList<KV<String, String>> createTestData() { + String[] scientists = { + "Einstein", + "Darwin", + "Copernicus", + "Pasteur", + "Curie", + "Faraday", + "Newton", + "Bohr", + "Galilei", + "Maxwell" + }; + ArrayList<KV<String, String>> data = new ArrayList<>(); + for (int i = 0; i < NUM_ELEMENTS; i++) { + int index = i % scientists.length; + KV<String, String> element = KV.of("key", scientists[index]); + data.add(element); + } + return data; + } + + @Test + @Category({NeedsRunner.class, UsesTimersInParDo.class, UsesStatefulParDo.class}) + public void testInGlobalWindow() { + PCollection<KV<String, Iterable<String>>> collection = + pipeline + .apply("Input data", Create.of(data)) + .apply(GroupIntoBatches.<String, String>ofSize(BATCH_SIZE)) + //set output coder + .setCoder(KvCoder.of(StringUtf8Coder.of(), IterableCoder.of(StringUtf8Coder.of()))); + PAssert.that("Incorrect batch size in one ore more elements", collection) + .satisfies( + new SerializableFunction<Iterable<KV<String, Iterable<String>>>, Void>() { + + private boolean checkBatchSizes(Iterable<KV<String, Iterable<String>>> listToCheck) { + for (KV<String, Iterable<String>> element : listToCheck) { + if (Iterables.size(element.getValue()) != BATCH_SIZE){ + return false; + } + } + return true; + } + + @Override + public Void apply(Iterable<KV<String, Iterable<String>>> input) { + assertTrue(checkBatchSizes(input)); + return null; + } + }); + PAssert.thatSingleton( + "Incorrect collection size", + collection.apply("Count", 
Count.<KV<String, Iterable<String>>>globally())) + .isEqualTo(NUM_ELEMENTS / BATCH_SIZE); + pipeline.run(); + } + + @Test + @Category({ + NeedsRunner.class, + UsesTimersInParDo.class, + UsesTestStream.class, + UsesStatefulParDo.class + }) + public void testInStreamingMode() { + int timestampInterval = 1; + Instant startInstant = new Instant(0L); + TestStream.Builder<KV<String, String>> streamBuilder = + TestStream.create(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of())) + .advanceWatermarkTo(startInstant); + long offset = 0L; + for (KV<String, String> element : data) { + streamBuilder = + streamBuilder.addElements( + TimestampedValue.of( + element, + startInstant.plus(Duration.standardSeconds(offset * timestampInterval)))); + offset++; + } + final long windowDuration = 6; + TestStream<KV<String, String>> stream = + streamBuilder + .advanceWatermarkTo(startInstant.plus(Duration.standardSeconds(windowDuration - 1))) + .advanceWatermarkTo(startInstant.plus(Duration.standardSeconds(windowDuration + 1))) + .advanceWatermarkTo(startInstant.plus(Duration.standardSeconds(NUM_ELEMENTS))) + .advanceWatermarkToInfinity(); + + PCollection<KV<String, String>> inputCollection = + pipeline + .apply(stream) + .apply( + Window.<KV<String, String>>into( + FixedWindows.of(Duration.standardSeconds(windowDuration))) + .withAllowedLateness(Duration.millis(ALLOWED_LATENESS))); + inputCollection.apply( + ParDo.of( + new DoFn<KV<String, String>, Void>() { + @ProcessElement + public void processElement(ProcessContext c, BoundedWindow window) { + LOGGER.debug( + "*** ELEMENT: ({},{}) *** with timestamp %s in window %s", + c.element().getKey(), + c.element().getValue(), + c.timestamp().toString(), + window.toString()); + } + })); + + PCollection<KV<String, Iterable<String>>> outputCollection = + inputCollection + .apply(GroupIntoBatches.<String, String>ofSize(BATCH_SIZE)) + .setCoder(KvCoder.of(StringUtf8Coder.of(), IterableCoder.of(StringUtf8Coder.of()))); + + // elements have the same 
key and collection is divided into windows, + // so Count.perKey values are the number of elements in windows + PCollection<KV<String, Long>> countOutput = + outputCollection.apply( + "Count elements in windows after applying GroupIntoBatches", + Count.<String, Iterable<String>>perKey()); + + PAssert.that("Wrong number of elements in windows after GroupIntoBatches", countOutput) + .satisfies( + new SerializableFunction<Iterable<KV<String, Long>>, Void>() { + + @Override + public Void apply(Iterable<KV<String, Long>> input) { + Iterator<KV<String, Long>> inputIterator = input.iterator(); + // first element + long count0 = inputIterator.next().getValue(); + // window duration is 6 and batch size is 5, so there should be 2 elements in the + // window (flush because batchSize reached and for end of window reached) + assertEquals("Wrong number of elements in first window", 2, count0); + // second element + long count1 = inputIterator.next().getValue(); + // collection is 10 elements, there is only 4 elements left, so there should be only + // one element in the window (flush because end of window/collection reached) + assertEquals("Wrong number of elements in second window", 1, count1); + // third element + return null; + } + }); + + PAssert.that("Incorrect output collection after GroupIntoBatches", outputCollection) + .satisfies( + new SerializableFunction<Iterable<KV<String, Iterable<String>>>, Void>() { + + @Override + public Void apply(Iterable<KV<String, Iterable<String>>> input) { + Iterator<KV<String, Iterable<String>>> inputIterator = input.iterator(); + // first element + int size0 = Iterables.size(inputIterator.next().getValue()); + // window duration is 6 and batch size is 5, so output batch size should de 5 + // (flush because of batchSize reached) + assertEquals("Wrong first element batch Size", 5, size0); + // second element + int size1 = Iterables.size(inputIterator.next().getValue()); + // there is only one element left in the window so batch size should 
be 1 + // (flush because of end of window reached) + assertEquals("Wrong second element batch Size", 1, size1); + // third element + int size2 = Iterables.size(inputIterator.next().getValue()); + // collection is 10 elements, there is only 4 left, so batch size should be 4 + // (flush because end of collection reached) + assertEquals("Wrong third element batch Size", 4, size2); + return null; + } + }); + pipeline.run().waitUntilFinish(); + } +}
