http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/071e4dd6/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java ---------------------------------------------------------------------- diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java new file mode 100644 index 0000000..6cf46e5 --- /dev/null +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java @@ -0,0 +1,715 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink.translation.wrappers.streaming.state; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.Combine; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.OutputTimeFn; +import com.google.cloud.dataflow.sdk.util.state.*; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.common.base.Preconditions; +import com.google.protobuf.ByteString; +import org.apache.flink.util.InstantiationUtil; +import org.joda.time.Instant; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.util.*; + +/** + * An implementation of the Beam {@link StateInternals}. This implementation simply keeps elements in memory. + * This state is periodically checkpointed by Flink, for fault-tolerance. + * + * TODO: State should be rewritten to redirect to Flink per-key state so that coders and combiners don't need + * to be serialized along with encoded values when snapshotting. + */ +public class FlinkStateInternals<K> implements StateInternals<K> { + + private final K key; + + private final Coder<K> keyCoder; + + private final Coder<? extends BoundedWindow> windowCoder; + + private final OutputTimeFn<? super BoundedWindow> outputTimeFn; + + private Instant watermarkHoldAccessor; + + public FlinkStateInternals(K key, + Coder<K> keyCoder, + Coder<? extends BoundedWindow> windowCoder, + OutputTimeFn<? 
super BoundedWindow> outputTimeFn) { + this.key = key; + this.keyCoder = keyCoder; + this.windowCoder = windowCoder; + this.outputTimeFn = outputTimeFn; + } + + public Instant getWatermarkHold() { + return watermarkHoldAccessor; + } + + /** + * This is the interface state has to implement in order for it to be fault tolerant when + * executed by the FlinkPipelineRunner. + */ + private interface CheckpointableIF { + + boolean shouldPersist(); + + void persistState(StateCheckpointWriter checkpointBuilder) throws IOException; + } + + protected final StateTable<K> inMemoryState = new StateTable<K>() { + @Override + protected StateTag.StateBinder binderForNamespace(final StateNamespace namespace, final StateContext<?> c) { + return new StateTag.StateBinder<K>() { + + @Override + public <T> ValueState<T> bindValue(StateTag<? super K, ValueState<T>> address, Coder<T> coder) { + return new FlinkInMemoryValue<>(encodeKey(namespace, address), coder); + } + + @Override + public <T> BagState<T> bindBag(StateTag<? super K, BagState<T>> address, Coder<T> elemCoder) { + return new FlinkInMemoryBag<>(encodeKey(namespace, address), elemCoder); + } + + @Override + public <InputT, AccumT, OutputT> AccumulatorCombiningState<InputT, AccumT, OutputT> bindCombiningValue( + StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> address, + Coder<AccumT> accumCoder, Combine.CombineFn<InputT, AccumT, OutputT> combineFn) { + return new FlinkInMemoryKeyedCombiningValue<>(encodeKey(namespace, address), combineFn, accumCoder, c); + } + + @Override + public <InputT, AccumT, OutputT> AccumulatorCombiningState<InputT, AccumT, OutputT> bindKeyedCombiningValue( + StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> address, + Coder<AccumT> accumCoder, + Combine.KeyedCombineFn<? 
super K, InputT, AccumT, OutputT> combineFn) { + return new FlinkInMemoryKeyedCombiningValue<>(encodeKey(namespace, address), combineFn, accumCoder, c); + } + + @Override + public <InputT, AccumT, OutputT> AccumulatorCombiningState<InputT, AccumT, OutputT> bindKeyedCombiningValueWithContext( + StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> address, + Coder<AccumT> accumCoder, + CombineWithContext.KeyedCombineFnWithContext<? super K, InputT, AccumT, OutputT> combineFn) { + return new FlinkInMemoryKeyedCombiningValue<>(encodeKey(namespace, address), combineFn, accumCoder, c); + } + + @Override + public <W extends BoundedWindow> WatermarkHoldState<W> bindWatermark(StateTag<? super K, WatermarkHoldState<W>> address, OutputTimeFn<? super W> outputTimeFn) { + return new FlinkWatermarkHoldStateImpl<>(encodeKey(namespace, address), outputTimeFn); + } + }; + } + }; + + @Override + public K getKey() { + return key; + } + + @Override + public <StateT extends State> StateT state(StateNamespace namespace, StateTag<? super K, StateT> address) { + return inMemoryState.get(namespace, address, null); + } + + @Override + public <T extends State> T state(StateNamespace namespace, StateTag<? super K, T> address, StateContext<?> c) { + return inMemoryState.get(namespace, address, c); + } + + public void persistState(StateCheckpointWriter checkpointBuilder) throws IOException { + checkpointBuilder.writeInt(getNoOfElements()); + + for (State location : inMemoryState.values()) { + if (!(location instanceof CheckpointableIF)) { + throw new IllegalStateException(String.format( + "%s wasn't created by %s -- unable to persist it", + location.getClass().getSimpleName(), + getClass().getSimpleName())); + } + ((CheckpointableIF) location).persistState(checkpointBuilder); + } + } + + public void restoreState(StateCheckpointReader checkpointReader, ClassLoader loader) + throws IOException, ClassNotFoundException { + + // the number of elements to read. 
+ int noOfElements = checkpointReader.getInt(); + for (int i = 0; i < noOfElements; i++) { + decodeState(checkpointReader, loader); + } + } + + /** + * We remove the first character which encodes the type of the stateTag ('s' for system + * and 'u' for user). For more details check out the source of + * {@link StateTags.StateTagBase#getId()}. + */ + private void decodeState(StateCheckpointReader reader, ClassLoader loader) + throws IOException, ClassNotFoundException { + + StateType stateItemType = StateType.deserialize(reader); + ByteString stateKey = reader.getTag(); + + // first decode the namespace and the tagId... + String[] namespaceAndTag = stateKey.toStringUtf8().split("\\+"); + if (namespaceAndTag.length != 2) { + throw new IllegalArgumentException("Invalid stateKey " + stateKey.toString() + "."); + } + StateNamespace namespace = StateNamespaces.fromString(namespaceAndTag[0], windowCoder); + + // ... decide if it is a system or user stateTag... + char ownerTag = namespaceAndTag[1].charAt(0); + if (ownerTag != 's' && ownerTag != 'u') { + throw new RuntimeException("Invalid StateTag name."); + } + boolean isSystemTag = ownerTag == 's'; + String tagId = namespaceAndTag[1].substring(1); + + // ...then decode the coder (if there is one)... + Coder<?> coder = null; + switch (stateItemType) { + case VALUE: + case LIST: + case ACCUMULATOR: + ByteString coderBytes = reader.getData(); + coder = InstantiationUtil.deserializeObject(coderBytes.toByteArray(), loader); + break; + case WATERMARK: + break; + } + + // ...then decode the combiner function (if there is one)... + CombineWithContext.KeyedCombineFnWithContext<? super K, ?, ?, ?> combineFn = null; + switch (stateItemType) { + case ACCUMULATOR: + ByteString combinerBytes = reader.getData(); + combineFn = InstantiationUtil.deserializeObject(combinerBytes.toByteArray(), loader); + break; + case VALUE: + case LIST: + case WATERMARK: + break; + } + + //... 
and finally, depending on the type of the state being decoded, + // 1) create the adequate stateTag, + // 2) create the state container, + // 3) restore the actual content. + switch (stateItemType) { + case VALUE: { + StateTag stateTag = StateTags.value(tagId, coder); + stateTag = isSystemTag ? StateTags.makeSystemTagInternal(stateTag) : stateTag; + @SuppressWarnings("unchecked") + FlinkInMemoryValue<?> value = (FlinkInMemoryValue<?>) inMemoryState.get(namespace, stateTag, null); + value.restoreState(reader); + break; + } + case WATERMARK: { + @SuppressWarnings("unchecked") + StateTag<Object, WatermarkHoldState<BoundedWindow>> stateTag = StateTags.watermarkStateInternal(tagId, outputTimeFn); + stateTag = isSystemTag ? StateTags.makeSystemTagInternal(stateTag) : stateTag; + @SuppressWarnings("unchecked") + FlinkWatermarkHoldStateImpl<?> watermark = (FlinkWatermarkHoldStateImpl<?>) inMemoryState.get(namespace, stateTag, null); + watermark.restoreState(reader); + break; + } + case LIST: { + StateTag stateTag = StateTags.bag(tagId, coder); + stateTag = isSystemTag ? StateTags.makeSystemTagInternal(stateTag) : stateTag; + FlinkInMemoryBag<?> bag = (FlinkInMemoryBag<?>) inMemoryState.get(namespace, stateTag, null); + bag.restoreState(reader); + break; + } + case ACCUMULATOR: { + @SuppressWarnings("unchecked") + StateTag<K, AccumulatorCombiningState<?, ?, ?>> stateTag = StateTags.keyedCombiningValueWithContext(tagId, (Coder) coder, combineFn); + stateTag = isSystemTag ? StateTags.makeSystemTagInternal(stateTag) : stateTag; + @SuppressWarnings("unchecked") + FlinkInMemoryKeyedCombiningValue<?, ?, ?> combiningValue = + (FlinkInMemoryKeyedCombiningValue<?, ?, ?>) inMemoryState.get(namespace, stateTag, null); + combiningValue.restoreState(reader); + break; + } + default: + throw new RuntimeException("Unknown State Type " + stateItemType + "."); + } + } + + private ByteString encodeKey(StateNamespace namespace, StateTag<? 
super K, ?> address) { + StringBuilder sb = new StringBuilder(); + try { + namespace.appendTo(sb); + sb.append('+'); + address.appendTo(sb); + } catch (IOException e) { + throw new RuntimeException(e); + } + return ByteString.copyFromUtf8(sb.toString()); + } + + private int getNoOfElements() { + int noOfElements = 0; + for (State state : inMemoryState.values()) { + if (!(state instanceof CheckpointableIF)) { + throw new RuntimeException("State Implementations used by the " + + "Flink Dataflow Runner should implement the CheckpointableIF interface."); + } + + if (((CheckpointableIF) state).shouldPersist()) { + noOfElements++; + } + } + return noOfElements; + } + + private final class FlinkInMemoryValue<T> implements ValueState<T>, CheckpointableIF { + + private final ByteString stateKey; + private final Coder<T> elemCoder; + + private T value = null; + + public FlinkInMemoryValue(ByteString stateKey, Coder<T> elemCoder) { + this.stateKey = stateKey; + this.elemCoder = elemCoder; + } + + @Override + public void clear() { + value = null; + } + + @Override + public void write(T input) { + this.value = input; + } + + @Override + public T read() { + return value; + } + + @Override + public ValueState<T> readLater() { + // Ignore + return this; + } + + @Override + public boolean shouldPersist() { + return value != null; + } + + @Override + public void persistState(StateCheckpointWriter checkpointBuilder) throws IOException { + if (value != null) { + // serialize the coder. 
+ byte[] coder = InstantiationUtil.serializeObject(elemCoder); + + // encode the value into a ByteString + ByteString.Output stream = ByteString.newOutput(); + elemCoder.encode(value, stream, Coder.Context.OUTER); + ByteString data = stream.toByteString(); + + checkpointBuilder.addValueBuilder() + .setTag(stateKey) + .setData(coder) + .setData(data); + } + } + + public void restoreState(StateCheckpointReader checkpointReader) throws IOException { + ByteString valueContent = checkpointReader.getData(); + T outValue = elemCoder.decode(new ByteArrayInputStream(valueContent.toByteArray()), Coder.Context.OUTER); + write(outValue); + } + } + + private final class FlinkWatermarkHoldStateImpl<W extends BoundedWindow> + implements WatermarkHoldState<W>, CheckpointableIF { + + private final ByteString stateKey; + + private Instant minimumHold = null; + + private OutputTimeFn<? super W> outputTimeFn; + + public FlinkWatermarkHoldStateImpl(ByteString stateKey, OutputTimeFn<? super W> outputTimeFn) { + this.stateKey = stateKey; + this.outputTimeFn = outputTimeFn; + } + + @Override + public void clear() { + // Even though we're clearing we can't remove this from the in-memory state map, since + // other users may already have a handle on this WatermarkBagInternal. + minimumHold = null; + watermarkHoldAccessor = null; + } + + @Override + public void add(Instant watermarkHold) { + if (minimumHold == null || minimumHold.isAfter(watermarkHold)) { + watermarkHoldAccessor = watermarkHold; + minimumHold = watermarkHold; + } + } + + @Override + public ReadableState<Boolean> isEmpty() { + return new ReadableState<Boolean>() { + @Override + public Boolean read() { + return minimumHold == null; + } + + @Override + public ReadableState<Boolean> readLater() { + // Ignore + return this; + } + }; + } + + @Override + public OutputTimeFn<? 
super W> getOutputTimeFn() { + return outputTimeFn; + } + + @Override + public Instant read() { + return minimumHold; + } + + @Override + public WatermarkHoldState<W> readLater() { + // Ignore + return this; + } + + @Override + public String toString() { + return Objects.toString(minimumHold); + } + + @Override + public boolean shouldPersist() { + return minimumHold != null; + } + + @Override + public void persistState(StateCheckpointWriter checkpointBuilder) throws IOException { + if (minimumHold != null) { + checkpointBuilder.addWatermarkHoldsBuilder() + .setTag(stateKey) + .setTimestamp(minimumHold); + } + } + + public void restoreState(StateCheckpointReader checkpointReader) throws IOException { + Instant watermark = checkpointReader.getTimestamp(); + add(watermark); + } + } + + + private static <K, InputT, AccumT, OutputT> CombineWithContext.KeyedCombineFnWithContext<K, InputT, AccumT, OutputT> withContext( + final Combine.KeyedCombineFn<K, InputT, AccumT, OutputT> combineFn) { + return new CombineWithContext.KeyedCombineFnWithContext<K, InputT, AccumT, OutputT>() { + @Override + public AccumT createAccumulator(K key, CombineWithContext.Context c) { + return combineFn.createAccumulator(key); + } + + @Override + public AccumT addInput(K key, AccumT accumulator, InputT value, CombineWithContext.Context c) { + return combineFn.addInput(key, accumulator, value); + } + + @Override + public AccumT mergeAccumulators(K key, Iterable<AccumT> accumulators, CombineWithContext.Context c) { + return combineFn.mergeAccumulators(key, accumulators); + } + + @Override + public OutputT extractOutput(K key, AccumT accumulator, CombineWithContext.Context c) { + return combineFn.extractOutput(key, accumulator); + } + }; + } + + private static <K, InputT, AccumT, OutputT> CombineWithContext.KeyedCombineFnWithContext<K, InputT, AccumT, OutputT> withKeyAndContext( + final Combine.CombineFn<InputT, AccumT, OutputT> combineFn) { + return new 
CombineWithContext.KeyedCombineFnWithContext<K, InputT, AccumT, OutputT>() { + @Override + public AccumT createAccumulator(K key, CombineWithContext.Context c) { + return combineFn.createAccumulator(); + } + + @Override + public AccumT addInput(K key, AccumT accumulator, InputT value, CombineWithContext.Context c) { + return combineFn.addInput(accumulator, value); + } + + @Override + public AccumT mergeAccumulators(K key, Iterable<AccumT> accumulators, CombineWithContext.Context c) { + return combineFn.mergeAccumulators(accumulators); + } + + @Override + public OutputT extractOutput(K key, AccumT accumulator, CombineWithContext.Context c) { + return combineFn.extractOutput(accumulator); + } + }; + } + + private final class FlinkInMemoryKeyedCombiningValue<InputT, AccumT, OutputT> + implements AccumulatorCombiningState<InputT, AccumT, OutputT>, CheckpointableIF { + + private final ByteString stateKey; + private final CombineWithContext.KeyedCombineFnWithContext<? super K, InputT, AccumT, OutputT> combineFn; + private final Coder<AccumT> accumCoder; + private final CombineWithContext.Context context; + + private AccumT accum = null; + private boolean isClear = true; + + private FlinkInMemoryKeyedCombiningValue(ByteString stateKey, + Combine.CombineFn<InputT, AccumT, OutputT> combineFn, + Coder<AccumT> accumCoder, + final StateContext<?> stateContext) { + this(stateKey, withKeyAndContext(combineFn), accumCoder, stateContext); + } + + + private FlinkInMemoryKeyedCombiningValue(ByteString stateKey, + Combine.KeyedCombineFn<? super K, InputT, AccumT, OutputT> combineFn, + Coder<AccumT> accumCoder, + final StateContext<?> stateContext) { + this(stateKey, withContext(combineFn), accumCoder, stateContext); + } + + private FlinkInMemoryKeyedCombiningValue(ByteString stateKey, + CombineWithContext.KeyedCombineFnWithContext<? 
super K, InputT, AccumT, OutputT> combineFn, + Coder<AccumT> accumCoder, + final StateContext<?> stateContext) { + Preconditions.checkNotNull(combineFn); + Preconditions.checkNotNull(accumCoder); + + this.stateKey = stateKey; + this.combineFn = combineFn; + this.accumCoder = accumCoder; + this.context = new CombineWithContext.Context() { + @Override + public PipelineOptions getPipelineOptions() { + return stateContext.getPipelineOptions(); + } + + @Override + public <T> T sideInput(PCollectionView<T> view) { + return stateContext.sideInput(view); + } + }; + accum = combineFn.createAccumulator(key, context); + } + + @Override + public void clear() { + accum = combineFn.createAccumulator(key, context); + isClear = true; + } + + @Override + public void add(InputT input) { + isClear = false; + accum = combineFn.addInput(key, accum, input, context); + } + + @Override + public AccumT getAccum() { + return accum; + } + + @Override + public ReadableState<Boolean> isEmpty() { + return new ReadableState<Boolean>() { + @Override + public ReadableState<Boolean> readLater() { + // Ignore + return this; + } + + @Override + public Boolean read() { + return isClear; + } + }; + } + + @Override + public void addAccum(AccumT accum) { + isClear = false; + this.accum = combineFn.mergeAccumulators(key, Arrays.asList(this.accum, accum), context); + } + + @Override + public AccumT mergeAccumulators(Iterable<AccumT> accumulators) { + return combineFn.mergeAccumulators(key, accumulators, context); + } + + @Override + public OutputT read() { + return combineFn.extractOutput(key, accum, context); + } + + @Override + public AccumulatorCombiningState<InputT, AccumT, OutputT> readLater() { + // Ignore + return this; + } + + @Override + public boolean shouldPersist() { + return !isClear; + } + + @Override + public void persistState(StateCheckpointWriter checkpointBuilder) throws IOException { + if (!isClear) { + // serialize the coder. 
+ byte[] coder = InstantiationUtil.serializeObject(accumCoder); + + // serialize the combiner. + byte[] combiner = InstantiationUtil.serializeObject(combineFn); + + // encode the accumulator into a ByteString + ByteString.Output stream = ByteString.newOutput(); + accumCoder.encode(accum, stream, Coder.Context.OUTER); + ByteString data = stream.toByteString(); + + // put the flag that the next serialized element is an accumulator + checkpointBuilder.addAccumulatorBuilder() + .setTag(stateKey) + .setData(coder) + .setData(combiner) + .setData(data); + } + } + + public void restoreState(StateCheckpointReader checkpointReader) throws IOException { + ByteString valueContent = checkpointReader.getData(); + AccumT accum = this.accumCoder.decode(new ByteArrayInputStream(valueContent.toByteArray()), Coder.Context.OUTER); + addAccum(accum); + } + } + + private static final class FlinkInMemoryBag<T> implements BagState<T>, CheckpointableIF { + private final List<T> contents = new ArrayList<>(); + + private final ByteString stateKey; + private final Coder<T> elemCoder; + + public FlinkInMemoryBag(ByteString stateKey, Coder<T> elemCoder) { + this.stateKey = stateKey; + this.elemCoder = elemCoder; + } + + @Override + public void clear() { + contents.clear(); + } + + @Override + public Iterable<T> read() { + return contents; + } + + @Override + public BagState<T> readLater() { + // Ignore + return this; + } + + @Override + public void add(T input) { + contents.add(input); + } + + @Override + public ReadableState<Boolean> isEmpty() { + return new ReadableState<Boolean>() { + @Override + public ReadableState<Boolean> readLater() { + // Ignore + return this; + } + + @Override + public Boolean read() { + return contents.isEmpty(); + } + }; + } + + @Override + public boolean shouldPersist() { + return !contents.isEmpty(); + } + + @Override + public void persistState(StateCheckpointWriter checkpointBuilder) throws IOException { + if (!contents.isEmpty()) { + // serialize the coder. 
+ byte[] coder = InstantiationUtil.serializeObject(elemCoder); + + checkpointBuilder.addListUpdatesBuilder() + .setTag(stateKey) + .setData(coder) + .writeInt(contents.size()); + + for (T item : contents) { + // encode the element + ByteString.Output stream = ByteString.newOutput(); + elemCoder.encode(item, stream, Coder.Context.OUTER); + ByteString data = stream.toByteString(); + + // add the data to the checkpoint. + checkpointBuilder.setData(data); + } + } + } + + public void restoreState(StateCheckpointReader checkpointReader) throws IOException { + int noOfValues = checkpointReader.getInt(); + for (int j = 0; j < noOfValues; j++) { + ByteString valueContent = checkpointReader.getData(); + T outValue = elemCoder.decode(new ByteArrayInputStream(valueContent.toByteArray()), Coder.Context.OUTER); + add(outValue); + } + } + } +}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/071e4dd6/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/StateCheckpointReader.java ---------------------------------------------------------------------- diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/StateCheckpointReader.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/StateCheckpointReader.java new file mode 100644 index 0000000..5aadccd --- /dev/null +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/StateCheckpointReader.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink.translation.wrappers.streaming.state; + +import org.apache.beam.runners.flink.translation.types.CoderTypeSerializer; +import com.google.protobuf.ByteString; +import org.apache.flink.core.memory.DataInputView; +import org.joda.time.Instant; + +import java.io.IOException; +import java.util.concurrent.TimeUnit; + +public class StateCheckpointReader { + + private final DataInputView input; + + public StateCheckpointReader(DataInputView in) { + this.input = in; + } + + public ByteString getTag() throws IOException { + return ByteString.copyFrom(readRawData()); + } + + public String getTagToString() throws IOException { + return input.readUTF(); + } + + public ByteString getData() throws IOException { + return ByteString.copyFrom(readRawData()); + } + + public int getInt() throws IOException { + validate(); + return input.readInt(); + } + + public byte getByte() throws IOException { + validate(); + return input.readByte(); + } + + public Instant getTimestamp() throws IOException { + validate(); + Long watermarkMillis = input.readLong(); + return new Instant(TimeUnit.MICROSECONDS.toMillis(watermarkMillis)); + } + + public <K> K deserializeKey(CoderTypeSerializer<K> keySerializer) throws IOException { + return deserializeObject(keySerializer); + } + + public <T> T deserializeObject(CoderTypeSerializer<T> objectSerializer) throws IOException { + return objectSerializer.deserialize(input); + } + + ///////// Helper Methods /////// + + private byte[] readRawData() throws IOException { + validate(); + int size = input.readInt(); + + byte[] serData = new byte[size]; + int bytesRead = input.read(serData); + if (bytesRead != size) { + throw new RuntimeException("Error while deserializing checkpoint. 
Not enough bytes in the input stream."); + } + return serData; + } + + private void validate() { + if (this.input == null) { + throw new RuntimeException("StateBackend not initialized yet."); + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/071e4dd6/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/StateCheckpointUtils.java ---------------------------------------------------------------------- diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/StateCheckpointUtils.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/StateCheckpointUtils.java new file mode 100644 index 0000000..b2dc33c --- /dev/null +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/StateCheckpointUtils.java @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink.translation.wrappers.streaming.state; + +import org.apache.beam.runners.flink.translation.types.CoderTypeSerializer; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.transforms.Combine; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.OutputTimeFn; +import com.google.cloud.dataflow.sdk.util.TimeDomain; +import com.google.cloud.dataflow.sdk.util.TimerInternals; +import com.google.cloud.dataflow.sdk.util.state.StateNamespace; +import com.google.cloud.dataflow.sdk.util.state.StateNamespaces; +import org.joda.time.Instant; + +import java.io.IOException; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +public class StateCheckpointUtils { + + public static <K> void encodeState(Map<K, FlinkStateInternals<K>> perKeyStateInternals, + StateCheckpointWriter writer, Coder<K> keyCoder) throws IOException { + CoderTypeSerializer<K> keySerializer = new CoderTypeSerializer<>(keyCoder); + + int noOfKeys = perKeyStateInternals.size(); + writer.writeInt(noOfKeys); + for (Map.Entry<K, FlinkStateInternals<K>> keyStatePair : perKeyStateInternals.entrySet()) { + K key = keyStatePair.getKey(); + FlinkStateInternals<K> state = keyStatePair.getValue(); + + // encode the key + writer.serializeKey(key, keySerializer); + + // write the associated state + state.persistState(writer); + } + } + + public static <K> Map<K, FlinkStateInternals<K>> decodeState( + StateCheckpointReader reader, + OutputTimeFn<? super BoundedWindow> outputTimeFn, + Coder<K> keyCoder, + Coder<? 
extends BoundedWindow> windowCoder, + ClassLoader classLoader) throws IOException, ClassNotFoundException { + + int noOfKeys = reader.getInt(); + Map<K, FlinkStateInternals<K>> perKeyStateInternals = new HashMap<>(noOfKeys); + perKeyStateInternals.clear(); + + CoderTypeSerializer<K> keySerializer = new CoderTypeSerializer<>(keyCoder); + for (int i = 0; i < noOfKeys; i++) { + + // decode the key. + K key = reader.deserializeKey(keySerializer); + + //decode the state associated to the key. + FlinkStateInternals<K> stateForKey = + new FlinkStateInternals<>(key, keyCoder, windowCoder, outputTimeFn); + stateForKey.restoreState(reader, classLoader); + perKeyStateInternals.put(key, stateForKey); + } + return perKeyStateInternals; + } + + ////////////// Encoding/Decoding the Timers //////////////// + + + public static <K> void encodeTimers(Map<K, Set<TimerInternals.TimerData>> allTimers, + StateCheckpointWriter writer, + Coder<K> keyCoder) throws IOException { + CoderTypeSerializer<K> keySerializer = new CoderTypeSerializer<>(keyCoder); + + int noOfKeys = allTimers.size(); + writer.writeInt(noOfKeys); + for (Map.Entry<K, Set<TimerInternals.TimerData>> timersPerKey : allTimers.entrySet()) { + K key = timersPerKey.getKey(); + + // encode the key + writer.serializeKey(key, keySerializer); + + // write the associated timers + Set<TimerInternals.TimerData> timers = timersPerKey.getValue(); + encodeTimerDataForKey(writer, timers); + } + } + + public static <K> Map<K, Set<TimerInternals.TimerData>> decodeTimers( + StateCheckpointReader reader, + Coder<? extends BoundedWindow> windowCoder, + Coder<K> keyCoder) throws IOException { + + int noOfKeys = reader.getInt(); + Map<K, Set<TimerInternals.TimerData>> activeTimers = new HashMap<>(noOfKeys); + activeTimers.clear(); + + CoderTypeSerializer<K> keySerializer = new CoderTypeSerializer<>(keyCoder); + for (int i = 0; i < noOfKeys; i++) { + + // decode the key. 
+ K key = reader.deserializeKey(keySerializer); + + // decode the associated timers. + Set<TimerInternals.TimerData> timers = decodeTimerDataForKey(reader, windowCoder); + activeTimers.put(key, timers); + } + return activeTimers; + } + + private static void encodeTimerDataForKey(StateCheckpointWriter writer, Set<TimerInternals.TimerData> timers) throws IOException { + // encode timers + writer.writeInt(timers.size()); + for (TimerInternals.TimerData timer : timers) { + String stringKey = timer.getNamespace().stringKey(); + + writer.setTag(stringKey); + writer.setTimestamp(timer.getTimestamp()); + writer.writeInt(timer.getDomain().ordinal()); + } + } + + private static Set<TimerInternals.TimerData> decodeTimerDataForKey( + StateCheckpointReader reader, Coder<? extends BoundedWindow> windowCoder) throws IOException { + + // decode the timers: first their number and then the content itself. + int noOfTimers = reader.getInt(); + Set<TimerInternals.TimerData> timers = new HashSet<>(noOfTimers); + for (int i = 0; i < noOfTimers; i++) { + String stringKey = reader.getTagToString(); + Instant instant = reader.getTimestamp(); + TimeDomain domain = TimeDomain.values()[reader.getInt()]; + + StateNamespace namespace = StateNamespaces.fromString(stringKey, windowCoder); + timers.add(TimerInternals.TimerData.of(namespace, instant, domain)); + } + return timers; + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/071e4dd6/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/StateCheckpointWriter.java ---------------------------------------------------------------------- diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/StateCheckpointWriter.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/StateCheckpointWriter.java new file mode 100644 index 0000000..18e118a --- /dev/null +++ 
b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/StateCheckpointWriter.java @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.wrappers.streaming.state; + +import org.apache.beam.runners.flink.translation.types.CoderTypeSerializer; +import com.google.protobuf.ByteString; +import org.apache.flink.runtime.state.AbstractStateBackend; +import org.joda.time.Instant; + +import java.io.IOException; +import java.util.concurrent.TimeUnit; + +/** + * Encodes Beam state and timers into a Flink {@link AbstractStateBackend.CheckpointStateOutputView}. + * Every public write method first checks that an output view was supplied; timestamps are + * written as microseconds (see {@link #setTimestamp(Instant)}). + */ +public class StateCheckpointWriter { + + private final AbstractStateBackend.CheckpointStateOutputView output; + + public static StateCheckpointWriter create(AbstractStateBackend.CheckpointStateOutputView output) { + return new StateCheckpointWriter(output); + } + + private StateCheckpointWriter(AbstractStateBackend.CheckpointStateOutputView output) { + this.output = output; + } + + ///////// Creating the serialized versions of the different types of state held by dataflow /////// + + public StateCheckpointWriter addValueBuilder() throws IOException { + validate(); + StateType.serialize(StateType.VALUE, this); + return this; + } + + public StateCheckpointWriter
addWatermarkHoldsBuilder() throws IOException { + validate(); + StateType.serialize(StateType.WATERMARK, this); + return this; + } + + public StateCheckpointWriter addListUpdatesBuilder() throws IOException { + validate(); + StateType.serialize(StateType.LIST, this); + return this; + } + + public StateCheckpointWriter addAccumulatorBuilder() throws IOException { + validate(); + StateType.serialize(StateType.ACCUMULATOR, this); + return this; + } + + ///////// Setting the tag for a given state element /////// + + public StateCheckpointWriter setTag(ByteString stateKey) throws IOException { + return writeData(stateKey.toByteArray()); + } + + public StateCheckpointWriter setTag(String stateKey) throws IOException { + validate(); + output.writeUTF(stateKey); // NOTE(review): writeUTF rejects keys whose modified-UTF-8 form exceeds 65535 bytes + return this; + } + + public <K> StateCheckpointWriter serializeKey(K key, CoderTypeSerializer<K> keySerializer) throws IOException { + return serializeObject(key, keySerializer); + } + + public <T> StateCheckpointWriter serializeObject(T object, CoderTypeSerializer<T> objectSerializer) throws IOException { + validate(); + objectSerializer.serialize(object, output); + return this; + } + + ///////// Write the actual serialized data ////////// + + public StateCheckpointWriter setData(ByteString data) throws IOException { + return writeData(data.toByteArray()); + } + + public StateCheckpointWriter setData(byte[] data) throws IOException { + return writeData(data); + } + + public StateCheckpointWriter setTimestamp(Instant timestamp) throws IOException { + validate(); + output.writeLong(TimeUnit.MILLISECONDS.toMicros(timestamp.getMillis())); // stored as microseconds + return this; + } + + public StateCheckpointWriter writeInt(int number) throws IOException { + validate(); + output.writeInt(number); + return this; + } + + public StateCheckpointWriter writeByte(byte b) throws IOException { + validate(); + output.writeByte(b); + return this; + } + + ///////// Helper Methods /////// + + private StateCheckpointWriter writeData(byte[] data) throws IOException { + validate();
+ output.writeInt(data.length); + output.write(data); + return this; + } + + private void validate() { // fail fast with a clear message when no output view was supplied to create() + if (this.output == null) { + throw new RuntimeException("StateBackend not initialized yet."); + } + } +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/071e4dd6/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/StateType.java ---------------------------------------------------------------------- diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/StateType.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/StateType.java new file mode 100644 index 0000000..5849773 --- /dev/null +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/StateType.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.wrappers.streaming.state; + +import java.io.IOException; + +/** + * The available types of state, as provided by the Beam SDK.
This class is used for serialization/deserialization + * purposes. + */ +public enum StateType { + + VALUE(0), + + WATERMARK(1), + + LIST(2), + + ACCUMULATOR(3); + + private final int numVal; // persisted id: written as a single byte by serialize() and matched by deserialize() + + StateType(int value) { + this.numVal = value; + } + + public static void serialize(StateType type, StateCheckpointWriter output) throws IOException { + if (output == null) { + throw new IllegalArgumentException("Cannot write to a null output."); + } + + if (type == null) { + throw new IllegalArgumentException("Cannot write a null state type."); + } + + output.writeByte((byte) type.numVal); + } + + public static StateType deserialize(StateCheckpointReader input) throws IOException { + if (input == null) { + throw new IllegalArgumentException("Cannot read from a null input."); + } + + int typeInt = (int) input.getByte(); + StateType resultType = null; + for (StateType st : values()) { + if (st.numVal == typeInt) { + resultType = st; + break; + } + } + if (resultType == null) { + throw new RuntimeException("Unknown State Type " + typeInt + "."); + } + + return resultType; + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/071e4dd6/runners/flink/runner/src/main/resources/log4j.properties ---------------------------------------------------------------------- diff --git a/runners/flink/runner/src/main/resources/log4j.properties b/runners/flink/runner/src/main/resources/log4j.properties new file mode 100644 index 0000000..4daaad1 --- /dev/null +++ b/runners/flink/runner/src/main/resources/log4j.properties @@ -0,0 +1,23 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership.
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ +# Send INFO-and-above messages from every logger to stderr through a single console appender. +log4j.rootLogger=INFO,console +log4j.appender.console=org.apache.log4j.ConsoleAppender +log4j.appender.console.target=System.err +log4j.appender.console.layout=org.apache.log4j.PatternLayout +log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{2}: %m%n http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/071e4dd6/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/AvroITCase.java ---------------------------------------------------------------------- diff --git a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/AvroITCase.java b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/AvroITCase.java new file mode 100644 index 0000000..3536f87 --- /dev/null +++ b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/AvroITCase.java @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.flink; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.AvroCoder; +import com.google.cloud.dataflow.sdk.io.AvroIO; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.common.base.Joiner; +import org.apache.flink.test.util.JavaProgramTestBase; + +/** Writes four {@code User} records with AvroIO, reads them back in a second pipeline and checks their text rendering. */ +public class AvroITCase extends JavaProgramTestBase { + + protected String resultPath; + protected String tmpPath; + + public AvroITCase(){ + } + + static final String[] EXPECTED_RESULT = new String[] { + "Joe red 3", + "Mary blue 4", + "Mark green 1", + "Julia purple 5" + }; + + @Override + protected void preSubmit() throws Exception { + resultPath = getTempDirPath("result"); + tmpPath = getTempDirPath("tmp"); + + } + + @Override + protected void postSubmit() throws Exception { + compareResultsByLinesInMemory(Joiner.on('\n').join(EXPECTED_RESULT), resultPath); + } + + @Override + protected void testProgram() throws Exception { + runProgram(tmpPath, resultPath); + } + + private static void runProgram(String tmpPath, String resultPath) { + Pipeline p = FlinkTestPipeline.createForBatch(); + + p + .apply(Create.of( + new User("Joe", 3, "red"), + new User("Mary", 4, "blue"), + new User("Mark", 1, "green"), + new User("Julia", 5, "purple")) + .withCoder(AvroCoder.of(User.class))) + + .apply(AvroIO.Write.to(tmpPath) + .withSchema(User.class)); + +
p.run(); // the second pipeline below reads back the Avro files written by this one + + p = FlinkTestPipeline.createForBatch(); + + p + .apply(AvroIO.Read.from(tmpPath).withSchema(User.class).withoutValidation()) + + .apply(ParDo.of(new DoFn<User, String>() { + @Override + public void processElement(ProcessContext c) throws Exception { + User u = c.element(); + String result = u.getName() + " " + u.getFavoriteColor() + " " + u.getFavoriteNumber(); + c.output(result); + } + })) + + .apply(TextIO.Write.to(resultPath)); + + p.run(); + } + + private static class User { + + private String name; + private int favoriteNumber; + private String favoriteColor; + + public User() {} // NOTE(review): presumably required by AvroCoder's reflection-based decoding; confirm + + public User(String name, int favoriteNumber, String favoriteColor) { + this.name = name; + this.favoriteNumber = favoriteNumber; + this.favoriteColor = favoriteColor; + } + + public String getName() { + return name; + } + + public String getFavoriteColor() { + return favoriteColor; + } + + public int getFavoriteNumber() { + return favoriteNumber; + } + } + +} + http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/071e4dd6/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/FlattenizeITCase.java ---------------------------------------------------------------------- diff --git a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/FlattenizeITCase.java b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/FlattenizeITCase.java new file mode 100644 index 0000000..5ae0e83 --- /dev/null +++ b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/FlattenizeITCase.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.Flatten; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionList; +import com.google.common.base.Joiner; +import org.apache.flink.test.util.JavaProgramTestBase; + +public class FlattenizeITCase extends JavaProgramTestBase { + + private String resultPath; + private String resultPath2; + + private static final String[] words = {"hello", "this", "is", "a", "DataSet!"}; + private static final String[] words2 = {"hello", "this", "is", "another", "DataSet!"}; + private static final String[] words3 = {"hello", "this", "is", "yet", "another", "DataSet!"}; + + @Override + protected void preSubmit() throws Exception { + resultPath = getTempDirPath("result"); + resultPath2 = getTempDirPath("result2"); + } + + @Override + protected void postSubmit() throws Exception { + String join = Joiner.on('\n').join(words); + String join2 = Joiner.on('\n').join(words2); + String join3 = Joiner.on('\n').join(words3); + compareResultsByLinesInMemory(join + "\n" + join2, resultPath); + compareResultsByLinesInMemory(join + "\n" + join2 + "\n" + join3, resultPath2); + } + + // Writes the flattened words and words2 to resultPath, then those two together with words3 to resultPath2. + @Override + protected void testProgram() throws Exception { + Pipeline p = FlinkTestPipeline.createForBatch(); + + PCollection<String> p1 = p.apply(Create.of(words)); + PCollection<String> p2 =
p.apply(Create.of(words2)); + + PCollectionList<String> list = PCollectionList.of(p1).and(p2); + + list.apply(Flatten.<String>pCollections()).apply(TextIO.Write.to(resultPath)); + + PCollection<String> p3 = p.apply(Create.of(words3)); + + PCollectionList<String> list2 = list.and(p3); // reuses list, so p1 and p2 feed two Flatten transforms + + list2.apply(Flatten.<String>pCollections()).apply(TextIO.Write.to(resultPath2)); + + p.run(); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/071e4dd6/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/FlinkTestPipeline.java ---------------------------------------------------------------------- diff --git a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/FlinkTestPipeline.java b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/FlinkTestPipeline.java new file mode 100644 index 0000000..aadda24 --- /dev/null +++ b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/FlinkTestPipeline.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.beam.runners.flink; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.PipelineResult; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.runners.PipelineRunner; + +/** + * {@link com.google.cloud.dataflow.sdk.Pipeline} for testing Dataflow programs on the + * {@link org.apache.beam.runners.flink.FlinkPipelineRunner}. + */ +public class FlinkTestPipeline extends Pipeline { + + /** + * Creates and returns a new test pipeline for batch execution. + * + * <p>Use {@link com.google.cloud.dataflow.sdk.testing.DataflowAssert} to add tests, then call + * {@link Pipeline#run} to execute the pipeline and check the tests. + */ + public static FlinkTestPipeline createForBatch() { + return create(false); + } + + /** + * Creates and returns a new test pipeline for streaming execution. + * + * <p>Use {@link com.google.cloud.dataflow.sdk.testing.DataflowAssert} to add tests, then call + * {@link Pipeline#run} to execute the pipeline and check the tests. + * + * @return The Test Pipeline + */ + public static FlinkTestPipeline createForStreaming() { + return create(true); + } + + /** + * Creates and returns a new test pipeline for streaming or batch execution. + * + * <p>Use {@link com.google.cloud.dataflow.sdk.testing.DataflowAssert} to add tests, then call + * {@link Pipeline#run} to execute the pipeline and check the tests. + * + * @param streaming <code>True</code> for streaming mode, <code>False</code> for batch. + * @return The Test Pipeline. + */ + private static FlinkTestPipeline create(boolean streaming) { + FlinkPipelineRunner flinkRunner = FlinkPipelineRunner.createForTest(streaming); + return new FlinkTestPipeline(flinkRunner, flinkRunner.getPipelineOptions()); + } + + private FlinkTestPipeline(PipelineRunner<?
extends PipelineResult> runner, + PipelineOptions options) { + super(runner, options); // private: instances come only from createForBatch() or createForStreaming() + } +} + http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/071e4dd6/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/JoinExamplesITCase.java ---------------------------------------------------------------------- diff --git a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/JoinExamplesITCase.java b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/JoinExamplesITCase.java new file mode 100644 index 0000000..f60056d --- /dev/null +++ b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/JoinExamplesITCase.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.beam.runners.flink; + +import org.apache.beam.runners.flink.util.JoinExamples; +import com.google.api.services.bigquery.model.TableRow; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.values.KV; // NOTE(review): unused in this test +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.base.Joiner; +import org.apache.flink.test.util.JavaProgramTestBase; + +import java.util.Arrays; +import java.util.List; + + +/** + * Runs the locally copied {@link JoinExamples#joinEvents} on the Flink batch runner and checks the joined rows. + */ +public class JoinExamplesITCase extends JavaProgramTestBase { + + protected String resultPath; + + public JoinExamplesITCase(){ + } + + private static final TableRow row1 = new TableRow() + .set("ActionGeo_CountryCode", "VM").set("SQLDATE", "20141212") + .set("Actor1Name", "BANGKOK").set("SOURCEURL", "http://cnn.com"); + private static final TableRow row2 = new TableRow() + .set("ActionGeo_CountryCode", "VM").set("SQLDATE", "20141212") + .set("Actor1Name", "LAOS").set("SOURCEURL", "http://www.chicagotribune.com"); + private static final TableRow row3 = new TableRow() + .set("ActionGeo_CountryCode", "BE").set("SQLDATE", "20141213") + .set("Actor1Name", "AFGHANISTAN").set("SOURCEURL", "http://cnn.com"); + static final TableRow[] EVENTS = new TableRow[] { + row1, row2, row3 + }; + static final List<TableRow> EVENT_ARRAY = Arrays.asList(EVENTS); + + private static final TableRow cc1 = new TableRow() + .set("FIPSCC", "VM").set("HumanName", "Vietnam"); + private static final TableRow cc2 = new TableRow() + .set("FIPSCC", "BE").set("HumanName", "Belgium"); + static final TableRow[] CCS = new TableRow[] { + cc1, cc2 + }; + static final List<TableRow> CC_ARRAY = Arrays.asList(CCS); + + static final String[] JOINED_EVENTS = new String[] { + "Country code: VM, Country name: Vietnam, Event info: Date:
20141212, Actor1: LAOS, " + + "url: http://www.chicagotribune.com", + "Country code: VM, Country name: Vietnam, Event info: Date: 20141212, Actor1: BANGKOK, " + + "url: http://cnn.com", + "Country code: BE, Country name: Belgium, Event info: Date: 20141213, Actor1: AFGHANISTAN, " + + "url: http://cnn.com" + }; + + @Override + protected void preSubmit() throws Exception { + resultPath = getTempDirPath("result"); + } + + @Override + protected void postSubmit() throws Exception { + compareResultsByLinesInMemory(Joiner.on('\n').join(JOINED_EVENTS), resultPath); + } + + @Override + protected void testProgram() throws Exception { + + Pipeline p = FlinkTestPipeline.createForBatch(); + + PCollection<TableRow> input1 = p.apply(Create.of(EVENT_ARRAY)); + PCollection<TableRow> input2 = p.apply(Create.of(CC_ARRAY)); + + PCollection<String> output = JoinExamples.joinEvents(input1, input2); + + output.apply(TextIO.Write.to(resultPath)); + + p.run(); + } +} + http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/071e4dd6/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/MaybeEmptyTestITCase.java ---------------------------------------------------------------------- diff --git a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/MaybeEmptyTestITCase.java b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/MaybeEmptyTestITCase.java new file mode 100644 index 0000000..199602c --- /dev/null +++ b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/MaybeEmptyTestITCase.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.VoidCoder; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import org.apache.flink.test.util.JavaProgramTestBase; + +import java.io.Serializable; + +public class MaybeEmptyTestITCase extends JavaProgramTestBase implements Serializable { // presumably Serializable because the anonymous DoFn below is created in an instance method and references this test + + protected String resultPath; + + protected final String expected = "test"; + + public MaybeEmptyTestITCase() { + } + + @Override + protected void preSubmit() throws Exception { + resultPath = getTempDirPath("result"); + } + + @Override + protected void postSubmit() throws Exception { + compareResultsByLinesInMemory(expected, resultPath); + } + + @Override + protected void testProgram() throws Exception { + // A single null Void element, encoded with VoidCoder, must still reach the DoFn, which outputs expected once. + Pipeline p = FlinkTestPipeline.createForBatch(); + + p.apply(Create.of((Void) null)).setCoder(VoidCoder.of()) + .apply(ParDo.of( + new DoFn<Void, String>() { + @Override + public void processElement(DoFn<Void, String>.ProcessContext c) { + c.output(expected); + } + })).apply(TextIO.Write.to(resultPath)); + p.run(); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/071e4dd6/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/ParDoMultiOutputITCase.java ---------------------------------------------------------------------- diff --git
a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/ParDoMultiOutputITCase.java b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/ParDoMultiOutputITCase.java new file mode 100644 index 0000000..403de29 --- /dev/null +++ b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/ParDoMultiOutputITCase.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.beam.runners.flink; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionTuple; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.cloud.dataflow.sdk.values.TupleTagList; +import com.google.common.base.Joiner; +import org.apache.flink.test.util.JavaProgramTestBase; + +import java.io.Serializable; + +/** Runs a multi-output ParDo (main output, two declared side outputs and one undeclared side output) and checks the marked-words side output. */ +public class ParDoMultiOutputITCase extends JavaProgramTestBase implements Serializable { + + private String resultPath; + + private static final String[] expectedWords = {"MAAA", "MAAFOOO"}; + + @Override + protected void preSubmit() throws Exception { + resultPath = getTempDirPath("result"); + } + + @Override + protected void postSubmit() throws Exception { + compareResultsByLinesInMemory(Joiner.on("\n").join(expectedWords), resultPath); + } + + @Override + protected void testProgram() throws Exception { + Pipeline p = FlinkTestPipeline.createForBatch(); + + PCollection<String> words = p.apply(Create.of("Hello", "Whatupmyman", "hey", "SPECIALthere", "MAAA", "MAAFOOO")); + + // Select words whose length is at or below a cut off, + // plus the lengths of words that are above the cut off. + // Also select words starting with "MAA"; only those are written out and checked. + final int wordLengthCutOff = 3; + // Create tags to use for the main and side outputs.
+ final TupleTag<String> wordsBelowCutOffTag = new TupleTag<String>(){}; + final TupleTag<Integer> wordLengthsAboveCutOffTag = new TupleTag<Integer>(){}; + final TupleTag<String> markedWordsTag = new TupleTag<String>(){}; + + PCollectionTuple results = + words.apply(ParDo + .withOutputTags(wordsBelowCutOffTag, TupleTagList.of(wordLengthsAboveCutOffTag) + .and(markedWordsTag)) + .of(new DoFn<String, String>() { + final TupleTag<String> specialWordsTag = new TupleTag<String>() { // NOTE: not declared in withOutputTags above + }; + + @Override + public void processElement(ProcessContext c) { + String word = c.element(); + if (word.length() <= wordLengthCutOff) { + c.output(word); + } else { + c.sideOutput(wordLengthsAboveCutOffTag, word.length()); + } + if (word.startsWith("MAA")) { + c.sideOutput(markedWordsTag, word); + } + + if (word.startsWith("SPECIAL")) { + c.sideOutput(specialWordsTag, word); + } + } + })); + + // Only the marked-words output is written and checked; the other outputs stay unconsumed. + PCollection<String> markedWords = results.get(markedWordsTag); + + markedWords.apply(TextIO.Write.to(resultPath)); + + p.run(); + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/071e4dd6/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/ReadSourceITCase.java ---------------------------------------------------------------------- diff --git a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/ReadSourceITCase.java b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/ReadSourceITCase.java new file mode 100644 index 0000000..323c41b --- /dev/null +++ b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/ReadSourceITCase.java @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements.
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.io.BoundedSource; +import com.google.cloud.dataflow.sdk.io.Read; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.base.Joiner; +import com.google.common.base.Preconditions; +import org.apache.flink.test.util.JavaProgramTestBase; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +/** Reads the integers 1 through 9 from a custom {@link BoundedSource} split by the configured parallelism and checks them. */ +public class ReadSourceITCase extends JavaProgramTestBase { + + protected String resultPath; + + public ReadSourceITCase(){ + } + + static final String[] EXPECTED_RESULT = new String[] { + "1", "2", "3", "4", "5", "6", "7", "8", "9"}; + + @Override + protected void preSubmit() throws Exception { + resultPath = getTempDirPath("result"); + } + + @Override + protected void postSubmit() throws Exception { +
compareResultsByLinesInMemory(Joiner.on('\n').join(EXPECTED_RESULT), resultPath); + } + + @Override + protected void testProgram() throws Exception { + runProgram(resultPath); + } + + private static void runProgram(String resultPath) { + + Pipeline p = FlinkTestPipeline.createForBatch(); + + PCollection<String> result = p + .apply(Read.from(new ReadSource(1, 10))) + .apply(ParDo.of(new DoFn<Integer, String>() { + @Override + public void processElement(ProcessContext c) throws Exception { + c.output(c.element().toString()); + } + })); + + result.apply(TextIO.Write.to(resultPath)); + p.run(); + } + + + private static class ReadSource extends BoundedSource<Integer> { + final int from; + final int to; + + ReadSource(int from, int to) { + this.from = from; + this.to = to; + } + + @Override + public List<ReadSource> splitIntoBundles(long desiredShardSizeBytes, PipelineOptions options) + throws Exception { + List<ReadSource> res = new ArrayList<>(); + FlinkPipelineOptions flinkOptions = options.as(FlinkPipelineOptions.class); + int numWorkers = flinkOptions.getParallelism(); + Preconditions.checkArgument(numWorkers > 0, "Number of workers should be larger than 0."); + + float step = 1.0f * (to - from) / numWorkers; + for (int i = 0; i < numWorkers; ++i) { + res.add(new ReadSource(Math.round(from + i * step), Math.round(from + (i + 1) * step))); + } + return res; + } + + @Override + public long getEstimatedSizeBytes(PipelineOptions options) throws Exception { + return 8 * (to - from); + } + + @Override + public boolean producesSortedKeys(PipelineOptions options) throws Exception { + return true; + } + + @Override + public BoundedReader<Integer> createReader(PipelineOptions options) throws IOException { + return new RangeReader(this); + } + + @Override + public void validate() {} + + @Override + public Coder<Integer> getDefaultOutputCoder() { + return BigEndianIntegerCoder.of(); + } + + private class RangeReader extends BoundedReader<Integer> { + private int current; // last record produced; starts one before from so the first advance yields from + +
public RangeReader(ReadSource source) { + this.current = source.from - 1; + } + + @Override + public boolean start() throws IOException { + return advance(); // Source.Reader contract: start() must position the reader on the first record + } + + @Override + public boolean advance() throws IOException { + current++; + return (current < to); + } + + @Override + public Integer getCurrent() { + return current; + } + + @Override + public void close() throws IOException { + // Nothing + } + + @Override + public BoundedSource<Integer> getCurrentSource() { + return ReadSource.this; + } + } + } +} + + http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/071e4dd6/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/RemoveDuplicatesEmptyITCase.java ---------------------------------------------------------------------- diff --git a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/RemoveDuplicatesEmptyITCase.java b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/RemoveDuplicatesEmptyITCase.java new file mode 100644 index 0000000..524554a --- /dev/null +++ b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/RemoveDuplicatesEmptyITCase.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.beam.runners.flink; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.RemoveDuplicates; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.base.Joiner; +import org.apache.flink.test.util.JavaProgramTestBase; + +import java.util.Collections; +import java.util.List; + + /** Integration test checking that {@link RemoveDuplicates} applied to an empty input, on the Flink batch test pipeline, writes a result matching the empty EXPECTED_RESULT. */ +public class RemoveDuplicatesEmptyITCase extends JavaProgramTestBase { + + protected String resultPath; + + public RemoveDuplicatesEmptyITCase(){ + } + + static final String[] EXPECTED_RESULT = new String[] {}; + + @Override + protected void preSubmit() throws Exception { + resultPath = getTempDirPath("result"); + } + + @Override + protected void postSubmit() throws Exception { + compareResultsByLinesInMemory(Joiner.on('\n').join(EXPECTED_RESULT), resultPath); + } + + @Override + protected void testProgram() throws Exception { + + List<String> strings = Collections.emptyList(); + + Pipeline p = FlinkTestPipeline.createForBatch(); + /* The coder is set explicitly, presumably because none can be inferred from an empty list of elements. */ + PCollection<String> input = + p.apply(Create.of(strings)) + .setCoder(StringUtf8Coder.of()); + + PCollection<String> output = + input.apply(RemoveDuplicates.<String>create()); + + output.apply(TextIO.Write.to(resultPath)); + p.run(); + } +} + http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/071e4dd6/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/RemoveDuplicatesITCase.java ---------------------------------------------------------------------- diff --git a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/RemoveDuplicatesITCase.java b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/RemoveDuplicatesITCase.java new file mode 100644 index 0000000..54e92aa --- /dev/null +++ 
b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/RemoveDuplicatesITCase.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.RemoveDuplicates; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.base.Joiner; +import org.apache.flink.test.util.JavaProgramTestBase; + +import java.util.Arrays; +import java.util.List; + + /** Integration test that feeds a list containing repeated keys through {@link RemoveDuplicates} on the Flink batch test pipeline and expects each distinct key (k1, k5, k2, k3) exactly once in the written output. */ +public class RemoveDuplicatesITCase extends JavaProgramTestBase { + + protected String resultPath; + + public RemoveDuplicatesITCase(){ + } + + static final String[] EXPECTED_RESULT = new String[] { + "k1", "k5", "k2", "k3"}; + + @Override + protected void preSubmit() throws Exception { + resultPath = getTempDirPath("result"); + } + + @Override + protected void postSubmit() throws Exception { + compareResultsByLinesInMemory(Joiner.on('\n').join(EXPECTED_RESULT), resultPath); + } + + @Override + protected void testProgram() throws 
Exception { + + List<String> strings = Arrays.asList("k1", "k5", "k5", "k2", "k1", "k2", "k3"); + + Pipeline p = FlinkTestPipeline.createForBatch(); + + PCollection<String> input = + p.apply(Create.of(strings)) + .setCoder(StringUtf8Coder.of()); + + PCollection<String> output = + input.apply(RemoveDuplicates.<String>create()); + + output.apply(TextIO.Write.to(resultPath)); + p.run(); + } +} + http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/071e4dd6/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/SideInputITCase.java ---------------------------------------------------------------------- diff --git a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/SideInputITCase.java b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/SideInputITCase.java new file mode 100644 index 0000000..7f73b83 --- /dev/null +++ b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/SideInputITCase.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.View; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import org.apache.flink.test.util.JavaProgramTestBase; + +import java.io.Serializable; + /** Integration test that exposes a single value as a singleton side input, reads it inside a {@link DoFn} applied to an unrelated one-element input, and checks that the side-input value is what gets written. NOTE(review): implements Serializable presumably because the anonymous DoFn in testProgram captures the enclosing test instance and must be serialized; verify before removing. */ +public class SideInputITCase extends JavaProgramTestBase implements Serializable { + + private static final String expected = "Hello!"; + + protected String resultPath; + + @Override + protected void testProgram() throws Exception { + + + Pipeline p = FlinkTestPipeline.createForBatch(); + + + final PCollectionView<String> sidesInput = p + .apply(Create.of(expected)) + .apply(View.<String>asSingleton()); + + p.apply(Create.of("bli")) + .apply(ParDo.of(new DoFn<String, String>() { + @Override + public void processElement(ProcessContext c) throws Exception { + String s = c.sideInput(sidesInput); + c.output(s); + } + }).withSideInputs(sidesInput)).apply(TextIO.Write.to(resultPath)); + + p.run(); + } + + @Override + protected void preSubmit() throws Exception { + resultPath = getTempDirPath("result"); + } + + @Override + protected void postSubmit() throws Exception { + compareResultsByLinesInMemory(expected, resultPath); + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/071e4dd6/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/TfIdfITCase.java ---------------------------------------------------------------------- diff --git a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/TfIdfITCase.java b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/TfIdfITCase.java new file mode 100644 index 0000000..b0fb7b8 --- /dev/null +++ 
b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/TfIdfITCase.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink; + +import com.google.cloud.dataflow.examples.complete.TfIdf; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.StringDelegateCoder; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.Keys; +import com.google.cloud.dataflow.sdk.transforms.RemoveDuplicates; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.base.Joiner; +import org.apache.flink.test.util.JavaProgramTestBase; + +import java.net.URI; + + /** Integration test that runs the example {@code TfIdf.ComputeTfIdf} transform over three small documents keyed by URI on the Flink batch test pipeline, then checks that the distinct words among its output keys match the expected vocabulary. */ +public class TfIdfITCase extends JavaProgramTestBase { + + protected String resultPath; + + public TfIdfITCase(){ + } + + static final String[] EXPECTED_RESULT = new String[] { + "a", "m", "n", "b", "c", "d"}; + + @Override + protected void preSubmit() throws Exception { + resultPath = getTempDirPath("result"); + } + + @Override + protected void postSubmit() throws Exception { + 
compareResultsByLinesInMemory(Joiner.on('\n').join(EXPECTED_RESULT), resultPath); + } + + @Override + protected void testProgram() throws Exception { + + Pipeline pipeline = FlinkTestPipeline.createForBatch(); + /* Register a coder for URI, which is used as the document key of the input below. */ + pipeline.getCoderRegistry().registerCoder(URI.class, StringDelegateCoder.of(URI.class)); + + PCollection<KV<String, KV<URI, Double>>> wordToUriAndTfIdf = pipeline + .apply(Create.of( + KV.of(new URI("x"), "a b c d"), + KV.of(new URI("y"), "a b c"), + KV.of(new URI("z"), "a m n"))) + .apply(new TfIdf.ComputeTfIdf()); + + PCollection<String> words = wordToUriAndTfIdf + .apply(Keys.<String>create()) + .apply(RemoveDuplicates.<String>create()); + + words.apply(TextIO.Write.to(resultPath)); + + pipeline.run(); + } +} + http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/071e4dd6/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/WordCountITCase.java ---------------------------------------------------------------------- diff --git a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/WordCountITCase.java b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/WordCountITCase.java new file mode 100644 index 0000000..2677c9e --- /dev/null +++ b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/WordCountITCase.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink; + +import com.google.cloud.dataflow.examples.WordCount; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.MapElements; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.base.Joiner; +import org.apache.flink.test.util.JavaProgramTestBase; + +import java.util.Arrays; +import java.util.List; + + /** Integration test that runs the example {@code WordCount.CountWords} and {@code WordCount.FormatAsTextFn} transforms over a fixed set of lines (one of them empty) on the Flink batch test pipeline and compares the written output with the expected per-word count lines. */ +public class WordCountITCase extends JavaProgramTestBase { + + protected String resultPath; + + public WordCountITCase(){ + } + + static final String[] WORDS_ARRAY = new String[] { + "hi there", "hi", "hi sue bob", + "hi sue", "", "bob hi"}; + + static final List<String> WORDS = Arrays.asList(WORDS_ARRAY); + + static final String[] COUNTS_ARRAY = new String[] { + "hi: 5", "there: 1", "sue: 2", "bob: 2"}; + + @Override + protected void preSubmit() throws Exception { + resultPath = getTempDirPath("result"); + } + + @Override + protected void postSubmit() throws Exception { + compareResultsByLinesInMemory(Joiner.on('\n').join(COUNTS_ARRAY), resultPath); + } + + @Override + protected void testProgram() throws Exception { + + Pipeline p = FlinkTestPipeline.createForBatch(); + + PCollection<String> input = p.apply(Create.of(WORDS)).setCoder(StringUtf8Coder.of()); + + input + .apply(new WordCount.CountWords()) + .apply(MapElements.via(new WordCount.FormatAsTextFn())) + 
.apply(TextIO.Write.to(resultPath)); + + p.run(); + } +} + 
